[ { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_3144ec70" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_a63edcee" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_715ba7ca" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 26, "end_line": 26, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_192ed8f7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_32252868" }, { "content": "from bs4 import BeautifulSoup", "chunk_type": "import", "name": "BeautifulSoup", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 29, "end_line": 29, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BeautifulSoup_1a33e1ac" }, { "content": "from tqdm import tqdm", "chunk_type": "import", "name": "tqdm", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 30, "end_line": 30, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_tqdm_ac08201f" }, { "content": "DOCS = Path(__file__).parent.resolve()", "chunk_type": "variable", "name": "DOCS", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 33, "end_line": 33, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DOCS_847a9cc8" }, { "content": "SITE = DOCS.parent / \"site\"", "chunk_type": "variable", "name": "SITE", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 34, "end_line": 34, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SITE_63c9c393" }, { "content": "LINK_PATTERN = re.compile(r\"(https?://[^\\s()<>]*[^\\s()<>.,:;!?\\'\\\"])\")", "chunk_type": 
"variable", "name": "LINK_PATTERN", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 35, "end_line": 35, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_LINK_PATTERN_6896d2b0" }, { "content": "def prepare_docs_markdown(clone_repos: bool = True):\n \"\"\"Build docs using mkdocs.\"\"\"\n print(\"Removing existing build artifacts\")\n shutil.rmtree(SITE, ignore_errors=True)\n shutil.rmtree(DOCS / \"repos\", ignore_errors=True)\n\n if clone_repos:\n # Get hub-sdk repo\n repo = \"https://github.com/ultralytics/hub-sdk\"\n local_dir = DOCS / \"repos\" / Path(repo).name\n os.system(f\"git clone {repo} {local_dir} --depth 1 --single-branch --branch main\")\n shutil.rmtree(DOCS / \"en/hub/sdk\", ignore_errors=True) # delete if exists\n shutil.copytree(local_dir / \"docs\", DOCS / \"en/hub/sdk\") # for docs\n shutil.rmtree(DOCS.parent / \"hub_sdk\", ignore_errors=True) # delete if exists\n shutil.copytree(local_dir / \"hub_sdk\", DOCS.parent / \"hub_sdk\") # for mkdocstrings\n print(f\"Cloned/Updated {repo} in {local_dir}\")\n\n # Get docs repo\n repo = \"https://github.com/ultralytics/docs\"\n local_dir = DOCS / \"repos\" / Path(repo).name\n os.system(f\"git clone {repo} {local_dir} --depth 1 --single-branch --branch main\")\n shutil.rmtree(DOCS / \"en/compare\", ignore_errors=True) # delete if exists\n shutil.copytree(local_dir / \"docs/en/compare\", DOCS / \"en/compare\") # for docs\n print(f\"Cloned/Updated {repo} in {local_dir}\")\n\n # Add frontmatter\n for file in tqdm((DOCS / \"en\").rglob(\"*.md\"), desc=\"Adding frontmatter\"):\n update_markdown_files(file)", "chunk_type": "function", "name": "prepare_docs_markdown", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 38, "end_line": 65, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Build docs using mkdocs.", "parameters": [ "clone_repos: bool" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_prepare_docs_markdown_db69cd29" }, { "content": "def update_page_title(file_path: Path, new_title: str):\n \"\"\"Update the title of an HTML file.\"\"\"\n with open(file_path, encoding=\"utf-8\") as file:\n content = file.read()\n\n # Replace the existing title with the new title\n updated_content = re.sub(r\".*?\", f\"{new_title}\", content)\n\n # Write the updated content back to the file\n with open(file_path, \"w\", encoding=\"utf-8\") as file:\n file.write(updated_content)", "chunk_type": "function", "name": "update_page_title", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 68, "end_line": 78, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Update the title of an HTML file.", "parameters": [ "file_path: Path", "new_title: str" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_update_page_title_939606c2" }, { "content": "def update_html_head(script: str = \"\"):\n \"\"\"Update the HTML head section of each file.\"\"\"\n html_files = Path(SITE).rglob(\"*.html\")\n for html_file in tqdm(html_files, desc=\"Processing HTML 
files\"):\n with html_file.open(\"r\", encoding=\"utf-8\") as file:\n html_content = file.read()\n\n if script in html_content: # script already in HTML file\n return\n\n head_end_index = html_content.lower().rfind(\"\")\n if head_end_index != -1:\n # Add the specified JavaScript to the HTML file just before the end of the head tag.\n new_html_content = html_content[:head_end_index] + script + html_content[head_end_index:]\n with html_file.open(\"w\", encoding=\"utf-8\") as file:\n file.write(new_html_content)", "chunk_type": "function", "name": "update_html_head", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 81, "end_line": 96, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Update the HTML head section of each file.", "parameters": [ "script: str" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_update_html_head_8f15dff9" }, { "content": "def update_subdir_edit_links(subdir: str = \"\", docs_url: str = \"\"):\n \"\"\"Update the HTML head section of each file.\"\"\"\n if str(subdir[0]) == \"/\":\n subdir = str(subdir[0])[1:]\n html_files = (SITE / subdir).rglob(\"*.html\")\n for html_file in tqdm(html_files, desc=\"Processing subdir files\", mininterval=1.0):\n with html_file.open(\"r\", encoding=\"utf-8\") as file:\n soup = BeautifulSoup(file, \"html.parser\")\n\n # Find the anchor tag and update its href attribute\n a_tag = soup.find(\"a\", {\"class\": \"md-content__button md-icon\"})\n if a_tag and a_tag[\"title\"] == \"Edit this page\":\n a_tag[\"href\"] = f\"{docs_url}{a_tag['href'].rpartition(subdir)[-1]}\"\n\n # Write the updated HTML back to the file\n with open(html_file, \"w\", encoding=\"utf-8\") as file:\n file.write(str(soup))", "chunk_type": "function", "name": "update_subdir_edit_links", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 99, "end_line": 115, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Update the HTML head section of each file.", "parameters": [ "subdir: str", "docs_url: str" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_update_subdir_edit_links_2605cb6a" }, { "content": "def update_markdown_files(md_filepath: Path):\n \"\"\"Create or update a Markdown file, ensuring frontmatter is present.\"\"\"\n if md_filepath.exists():\n content = md_filepath.read_text().strip()\n\n # Replace apostrophes\n content = content.replace(\"‘\", \"'\").replace(\"’\", \"'\")\n\n # Add frontmatter if missing\n if not content.strip().startswith(\"---\\n\") and \"macros\" not in md_filepath.parts: # skip macros directory\n header = \"---\\ncomments: true\\ndescription: TODO ADD DESCRIPTION\\nkeywords: TODO ADD KEYWORDS\\n---\\n\\n\"\n content = header + content\n\n # Ensure MkDocs admonitions \"=== \" lines are preceded and followed by empty newlines\n lines = content.split(\"\\n\")\n new_lines = []\n for i, line in enumerate(lines):\n stripped_line = line.strip()\n if stripped_line.startswith(\"=== \"):\n if i > 0 and new_lines[-1] != \"\":\n new_lines.append(\"\")\n new_lines.append(line)\n if i < len(lines) - 1 and lines[i + 1].strip() != \"\":\n new_lines.append(\"\")\n else:\n new_lines.append(line)\n content = 
\"\\n\".join(new_lines)\n\n # Add EOF newline if missing\n if not content.endswith(\"\\n\"):\n content += \"\\n\"\n\n # Save page\n md_filepath.write_text(content)\n return", "chunk_type": "function", "name": "update_markdown_files", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 118, "end_line": 152, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": "Create or update a Markdown file, ensuring frontmatter is present.", "parameters": [ "md_filepath: Path" ], "return_type": null, "decorators": [], "complexity_score": 8, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_update_markdown_files_52904885" }, { "content": "def update_docs_html():\n \"\"\"Update titles, edit links, head sections, and convert plaintext links in HTML documentation.\"\"\"\n # Update 404 titles\n update_page_title(SITE / \"404.html\", new_title=\"Ultralytics Docs - Not Found\")\n\n # Update edit button links\n for subdir, docs_url in (\n (\"hub/sdk/\", \"https://github.com/ultralytics/hub-sdk/tree/main/docs/\"), # do not use leading slash\n (\"compare/\", \"https://github.com/ultralytics/docs/tree/main/docs/en/compare/\"),\n ):\n update_subdir_edit_links(subdir=subdir, docs_url=docs_url)\n\n # Convert plaintext links to HTML hyperlinks\n files_modified = 0\n for html_file in tqdm(SITE.rglob(\"*.html\"), desc=\"Updating bs4 soup\", mininterval=1.0):\n with open(html_file, encoding=\"utf-8\") as file:\n content = file.read()\n updated_content = update_docs_soup(content, html_file=html_file)\n if updated_content != content:\n with open(html_file, \"w\", encoding=\"utf-8\") as file:\n file.write(updated_content)\n files_modified += 1\n print(f\"Modified bs4 soup in {files_modified} files.\")\n\n # Update HTML file head section\n script = \"\"\n if any(script):\n update_html_head(script)\n\n # Delete the /macros directory from the built site\n macros_dir = SITE / \"macros\"\n if macros_dir.exists():\n print(f\"Removing /macros directory from site: {macros_dir}\")\n shutil.rmtree(macros_dir)", "chunk_type": "function", "name": "update_docs_html", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 155, "end_line": 188, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Update titles, edit links, head sections, and convert plaintext links in HTML documentation.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_update_docs_html_211fe039" }, { "content": "def update_docs_soup(content: str, html_file: Path = None, max_title_length: int = 70) -> str:\n \"\"\"Convert plaintext links to HTML hyperlinks, truncate long meta titles, and remove code line hrefs.\"\"\"\n soup = BeautifulSoup(content, \"html.parser\")\n modified = False\n\n # Truncate long meta title if needed\n title_tag = soup.find(\"title\")\n if title_tag and len(title_tag.text) > max_title_length and \"-\" in title_tag.text:\n title_tag.string = title_tag.text.rsplit(\"-\", 1)[0].strip()\n modified = True\n\n # Find the main content area\n main_content = soup.find(\"main\") or soup.find(\"div\", class_=\"md-content\")\n if not main_content:\n return str(soup) if modified else content\n\n # Convert plaintext links to HTML hyperlinks\n for paragraph in 
main_content.select(\"p, li\"):\n for text_node in paragraph.find_all(string=True, recursive=False):\n if text_node.parent.name not in {\"a\", \"code\"}:\n new_text = LINK_PATTERN.sub(r'\\1', str(text_node))\n if \" str:\n \"\"\"\n Remove comments and empty lines from a string of code, preserving newlines and URLs.\n\n Args:\n content (str): Code content to process.\n file_type (str): Type of file ('html', 'css', or 'js').\n\n Returns:\n (str): Cleaned content with comments and empty lines removed.\n\n Notes:\n Typical reductions for Ultralytics Docs are:\n - Total HTML reduction: 2.83% (1301.56 KB saved)\n - Total CSS reduction: 1.75% (2.61 KB saved)\n - Total JS reduction: 13.51% (99.31 KB saved)\n \"\"\"\n if file_type == \"html\":\n # Remove HTML comments\n content = re.sub(r\"\", \"\", content)\n # Only remove empty lines for HTML, preserve indentation\n content = re.sub(r\"^\\s*$\\n\", \"\", content, flags=re.MULTILINE)\n elif file_type == \"css\":\n # Remove CSS comments\n content = re.sub(r\"/\\*[\\s\\S]*?\\*/\", \"\", content)\n # Remove whitespace around specific characters\n content = re.sub(r\"\\s*([{}:;,])\\s*\", r\"\\1\", content)\n # Remove empty lines\n content = re.sub(r\"^\\s*\\n\", \"\", content, flags=re.MULTILINE)\n # Collapse multiple spaces to single space\n content = re.sub(r\"\\s{2,}\", \" \", content)\n # Remove all newlines\n content = re.sub(r\"\\n\", \"\", content)\n elif file_type == \"js\":\n # Handle JS single-line comments (preserving http:// and https://)\n lines = content.split(\"\\n\")\n processed_lines = []\n for line in lines:\n # Only remove comments if they're not part of a URL\n if \"//\" in line and \"http://\" not in line and \"https://\" not in line:\n processed_lines.append(line.partition(\"//\")[0])\n else:\n processed_lines.append(line)\n content = \"\\n\".join(processed_lines)\n\n # Remove JS multi-line comments and clean whitespace\n content = re.sub(r\"/\\*[\\s\\S]*?\\*/\", \"\", content)\n # Remove empty lines\n content = re.sub(r\"^\\s*\\n\", \"\", content, flags=re.MULTILINE)\n # Collapse multiple spaces to single space\n content = re.sub(r\"\\s{2,}\", \" \", content)\n\n # Safe space removal around punctuation and operators (NEVER include colons - breaks JS)\n content = re.sub(r\"\\s*([,;{}])\\s*\", r\"\\1\", content)\n content = re.sub(r\"(\\w)\\s*\\(|\\)\\s*{|\\s*([+\\-*/=])\\s*\", lambda m: m.group(0).replace(\" \", \"\"), content)\n\n return content", "chunk_type": "function", "name": "remove_comments_and_empty_lines", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 256, "end_line": 312, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Remove comments and empty lines from a string of code, preserving newlines and URLs.\n\nArgs:\n content (str): Code content to process.\n file_type (str): Type of file ('html', 'css', or 'js').\n\nReturns:\n (str): Cleaned content with comments and empty lines removed.\n\nNotes:\n Typical reductions for Ultralytics Docs are:\n - Total HTML reduction: 2.83% (1301.56 KB saved)\n - Total CSS reduction: 1.75% (2.61 KB saved)\n - Total JS reduction: 13.51% (99.31 KB saved)", "parameters": [ "content: str", "file_type: str" ], "return_type": "str", "decorators": [], "complexity_score": 6, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_remove_comments_and_empty_lines_67c4d49f" }, { "content": "def minify_files(html: bool = True, css: bool = 
True, js: bool = True):\n \"\"\"Minify HTML, CSS, and JS files and print total reduction stats.\"\"\"\n minify, compress, jsmin = None, None, None\n try:\n if html:\n from minify_html import minify\n if css:\n from csscompressor import compress\n if js:\n import jsmin\n except ImportError as e:\n print(f\"Missing required package: {str(e)}\")\n return\n\n stats = {}\n for ext, minifier in {\n \"html\": (lambda x: minify(x, keep_closing_tags=True, minify_css=True, minify_js=True)) if html else None,\n \"css\": compress if css else None,\n \"js\": jsmin.jsmin if js else None,\n }.items():\n stats[ext] = {\"original\": 0, \"minified\": 0}\n directory = \"\" # \"stylesheets\" if ext == css else \"javascript\" if ext == \"js\" else \"\"\n for f in tqdm((SITE / directory).rglob(f\"*.{ext}\"), desc=f\"Minifying {ext.upper()}\", mininterval=1.0):\n content = f.read_text(encoding=\"utf-8\")\n minified = minifier(content) if minifier else remove_comments_and_empty_lines(content, ext)\n stats[ext][\"original\"] += len(content)\n stats[ext][\"minified\"] += len(minified)\n f.write_text(minified, encoding=\"utf-8\")\n\n for ext, data in stats.items():\n if data[\"original\"]:\n r = data[\"original\"] - data[\"minified\"] # reduction\n print(f\"Total {ext.upper()} reduction: {(r / data['original']) * 100:.2f}% ({r / 1024:.2f} KB saved)\")", "chunk_type": "function", "name": "minify_files", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 315, "end_line": 347, "start_col": 0, "end_col": 114, "parent_name": null, "docstring": "Minify HTML, CSS, and JS files and print total reduction stats.", "parameters": [ "html: bool", "css: bool", "js: bool" ], "return_type": null, "decorators": [], "complexity_score": 9, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_minify_files_46ed1bc2" }, { "content": "def main():\n \"\"\"Build docs, update titles and edit links, minify HTML, and print local server command.\"\"\"\n prepare_docs_markdown()\n\n # Build the main documentation\n print(f\"Building docs from {DOCS}\")\n subprocess.run(f\"mkdocs build -f {DOCS.parent}/mkdocs.yml --strict\", check=True, shell=True)\n remove_macros()\n print(f\"Site built at {SITE}\")\n\n # Update docs HTML pages\n update_docs_html()\n\n # Minify files\n minify_files(html=False, css=False, js=False)\n\n # Cleanup\n shutil.rmtree(DOCS.parent / \"hub_sdk\", ignore_errors=True)\n shutil.rmtree(DOCS / \"repos\", ignore_errors=True)\n\n # Print results\n size = sum(f.stat().st_size for f in SITE.rglob(\"*\") if f.is_file()) >> 20\n print(\n f\"Docs built correctly ✅ ({size:.1f} MB)\\n\"\n f'Serve site at http://localhost:8000 with \"python -m http.server --directory site\"'\n )", "chunk_type": "function", "name": "main", "file_path": "ultralytics\\docs\\build_docs.py", "start_line": 350, "end_line": 375, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build docs, update titles and edit links, minify HTML, and print local server command.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "re", "shutil", "subprocess", "pathlib.Path", "bs4.BeautifulSoup", "tqdm.tqdm", "minify_html.minify", "csscompressor.compress", "jsmin" ], "chunk_id": "function_main_9ed512b0" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 11, "end_line": 11, "start_col": 
0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_11b93c45" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_05afc4b6" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_322329e1" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_2091f3b6" }, { "content": "hub_sdk = False", "chunk_type": "variable", "name": "hub_sdk", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_hub_sdk_de71f91e" }, { "content": "MKDOCS_YAML = PACKAGE_DIR.parent / \"mkdocs.yml\"", "chunk_type": "variable", "name": "MKDOCS_YAML", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 28, "end_line": 28, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_MKDOCS_YAML_41d7c530" }, { "content": "def extract_classes_and_functions(filepath: Path) -> tuple[list[str], list[str]]:\n \"\"\"Extract class and function names from a given Python file.\"\"\"\n content = filepath.read_text()\n return (re.findall(r\"(?:^|\\n)class\\s(\\w+)(?:\\(|:)\", content), re.findall(r\"(?:^|\\n)def\\s(\\w+)\\(\", content))", "chunk_type": "function", "name": "extract_classes_and_functions", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 31, "end_line": 34, "start_col": 0, "end_col": 111, "parent_name": null, "docstring": "Extract class and function names from a given Python file.", "parameters": [ "filepath: Path" ], "return_type": "tuple[list[str], list[str]]", "decorators": [], "complexity_score": 1, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_extract_classes_and_functions_27cbdcbe" }, { "content": "def create_markdown(py_filepath: Path, module_path: str, classes: list[str], functions: list[str]) -> Path:\n \"\"\"Create a Markdown file containing the API reference for the given Python module.\"\"\"\n md_filepath = py_filepath.with_suffix(\".md\")\n exists = md_filepath.exists()\n\n # Read existing content and retain header metadata if available\n header_content = \"\"\n if exists:\n existing_content = md_filepath.read_text()\n header_parts 
= existing_content.split(\"---\")\n for part in header_parts:\n if \"description:\" in part or \"comments:\" in part:\n header_content += f\"---{part}---\\n\\n\"\n if not any(header_content):\n header_content = \"---\\ndescription: TODO ADD DESCRIPTION\\nkeywords: TODO ADD KEYWORDS\\n---\\n\\n\"\n\n module_name = module_path.replace(\".__init__\", \"\")\n module_path = module_path.replace(\".\", \"/\")\n url = f\"https://github.com/{GITHUB_REPO}/blob/main/{module_path}.py\"\n edit = f\"https://github.com/{GITHUB_REPO}/edit/main/{module_path}.py\"\n pretty = url.replace(\"__init__.py\", \"\\\\_\\\\_init\\\\_\\\\_.py\") # Properly display __init__.py filenames\n\n # Build markdown content\n title_content = (\n f\"# Reference for `{module_path}.py`\\n\\n\"\n f\"!!! note\\n\\n\"\n f\" This file is available at [{pretty}]({url}). If you spot a problem please help fix it by [contributing]\"\n f\"(https://docs.ultralytics.com/help/contributing/) a [Pull Request]({edit}) 🛠️. Thank you 🙏!\\n\\n\"\n )\n md_content = [\"
\\n\\n\"]\n md_content.extend(f\"## ::: {module_name}.{cls}\\n\\n



\\n\\n\" for cls in classes)\n md_content.extend(f\"## ::: {module_name}.{func}\\n\\n



\\n\\n\" for func in functions)\n if md_content[-1:]: # Remove last horizontal rule if content exists\n md_content[-1] = md_content[-1].replace(\"

\\n\\n\", \"\")\n\n # Write to file\n md_filepath.parent.mkdir(parents=True, exist_ok=True)\n md_filepath.write_text(header_content + title_content + \"\".join(md_content) + \"\\n\")\n\n if not exists:\n print(f\"Created new file '{md_filepath}'\")\n subprocess.run([\"git\", \"add\", \"-f\", str(md_filepath)], check=True, cwd=PACKAGE_DIR)\n\n return md_filepath.relative_to(PACKAGE_DIR.parent)", "chunk_type": "function", "name": "create_markdown", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 37, "end_line": 80, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": "Create a Markdown file containing the API reference for the given Python module.", "parameters": [ "py_filepath: Path", "module_path: str", "classes: list[str]", "functions: list[str]" ], "return_type": "Path", "decorators": [], "complexity_score": 9, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_create_markdown_c4fd72d4" }, { "content": "def nested_dict():\n \"\"\"Create and return a nested defaultdict.\"\"\"\n return defaultdict(nested_dict)", "chunk_type": "function", "name": "nested_dict", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 83, "end_line": 85, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Create and return a nested defaultdict.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_nested_dict_271b7e37" }, { "content": "def sort_nested_dict(d: dict) -> dict:\n \"\"\"Sort a nested dictionary recursively.\"\"\"\n return {k: sort_nested_dict(v) if isinstance(v, dict) else v for k, v in sorted(d.items())}", "chunk_type": "function", "name": "sort_nested_dict", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 88, "end_line": 90, "start_col": 0, "end_col": 95, "parent_name": null, "docstring": "Sort a nested dictionary recursively.", "parameters": [ "d: dict" ], "return_type": "dict", "decorators": [], "complexity_score": 2, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_sort_nested_dict_25d4958f" }, { "content": "def create_nav_menu_yaml(nav_items: list[str]) -> str:\n \"\"\"Create and return a YAML string for the navigation menu.\"\"\"\n nav_tree = nested_dict()\n\n for item_str in nav_items:\n item = Path(item_str)\n parts = item.parts\n current_level = nav_tree[\"reference\"]\n for part in parts[2:-1]: # Skip docs/reference and filename\n current_level = current_level[part]\n current_level[parts[-1].replace(\".md\", \"\")] = item\n\n def _dict_to_yaml(d, level=0):\n \"\"\"Convert a nested dictionary to a YAML-formatted string with indentation.\"\"\"\n yaml_str = \"\"\n indent = \" \" * level\n for k, v in sorted(d.items()):\n if isinstance(v, dict):\n yaml_str += f\"{indent}- {k}:\\n{_dict_to_yaml(v, level + 1)}\"\n else:\n yaml_str += f\"{indent}- {k}: {str(v).replace('docs/en/', '')}\\n\"\n return yaml_str\n\n reference_yaml = _dict_to_yaml(sort_nested_dict(nav_tree))\n print(f\"Scan complete, generated reference section with {len(reference_yaml.splitlines())} lines\")\n return reference_yaml", "chunk_type": "function", "name": "create_nav_menu_yaml", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 93, "end_line": 118, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Create and return a YAML string for the navigation menu.", "parameters": [ 
"nav_items: list[str]" ], "return_type": "str", "decorators": [], "complexity_score": 5, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_create_nav_menu_yaml_8cf55b3f" }, { "content": "def extract_document_paths(yaml_section: str) -> list[str]:\n \"\"\"Extract document paths from a YAML section, ignoring formatting and structure.\"\"\"\n paths = []\n # Match all paths that appear after a colon in the YAML\n path_matches = re.findall(r\":\\s*([^\\s][^:\\n]*?)(?:\\n|$)\", yaml_section)\n for path in path_matches:\n # Clean up the path\n path = path.strip()\n if path and not path.startswith(\"-\") and not path.endswith(\":\"):\n paths.append(path)\n return sorted(paths)", "chunk_type": "function", "name": "extract_document_paths", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 121, "end_line": 131, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Extract document paths from a YAML section, ignoring formatting and structure.", "parameters": [ "yaml_section: str" ], "return_type": "list[str]", "decorators": [], "complexity_score": 3, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_extract_document_paths_97a5461c" }, { "content": "def update_mkdocs_file(reference_yaml: str) -> None:\n \"\"\"Update the mkdocs.yaml file with the new reference section only if changes in document paths are detected.\"\"\"\n mkdocs_content = MKDOCS_YAML.read_text()\n\n # Find the top-level Reference section\n ref_pattern = r\"(\\n - Reference:[\\s\\S]*?)(?=\\n - \\w|$)\"\n ref_match = re.search(ref_pattern, mkdocs_content)\n\n # Build new section with proper indentation\n new_section_lines = [\"\\n - Reference:\"]\n for line in reference_yaml.splitlines():\n if line.strip() == \"- reference:\": # Skip redundant header\n continue\n new_section_lines.append(f\" {line}\")\n new_ref_section = \"\\n\".join(new_section_lines) + \"\\n\"\n\n if ref_match:\n # We found an existing Reference section\n ref_section = ref_match.group(1)\n print(f\"Found existing top-level Reference section ({len(ref_section)} chars)\")\n\n # Compare only document paths\n existing_paths = extract_document_paths(ref_section)\n new_paths = extract_document_paths(new_ref_section)\n\n # Check if the document paths are the same (ignoring structure or formatting differences)\n if len(existing_paths) == len(new_paths) and set(existing_paths) == set(new_paths):\n print(f\"No changes detected in document paths ({len(existing_paths)} items). 
Skipping update.\")\n return\n\n print(f\"Changes detected: {len(new_paths)} document paths vs {len(existing_paths)} existing\")\n\n # Update content\n new_content = mkdocs_content.replace(ref_section, new_ref_section)\n MKDOCS_YAML.write_text(new_content)\n subprocess.run([\"npx\", \"prettier\", \"--write\", str(MKDOCS_YAML)], check=False, cwd=PACKAGE_DIR.parent)\n print(f\"Updated Reference section in {MKDOCS_YAML}\")\n else:\n # No existing Reference section, we need to add it\n help_match = re.search(r\"(\\n - Help:)\", mkdocs_content)\n if help_match:\n help_section = help_match.group(1)\n # Insert before Help section\n new_content = mkdocs_content.replace(help_section, f\"{new_ref_section}{help_section}\")\n MKDOCS_YAML.write_text(new_content)\n print(f\"Added new Reference section before Help in {MKDOCS_YAML}\")\n else:\n print(\"Could not find a suitable location to add Reference section\")", "chunk_type": "function", "name": "update_mkdocs_file", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 134, "end_line": 181, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": "Update the mkdocs.yaml file with the new reference section only if changes in document paths are detected.", "parameters": [ "reference_yaml: str" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_update_mkdocs_file_debcff5b" }, { "content": "def main():\n \"\"\"Extract class/function names, create Markdown files, and update mkdocs.yaml.\"\"\"\n nav_items = []\n\n for py_filepath in PACKAGE_DIR.rglob(\"*.py\"):\n classes, functions = extract_classes_and_functions(py_filepath)\n if classes or functions:\n py_filepath_rel = py_filepath.relative_to(PACKAGE_DIR)\n md_filepath = REFERENCE_DIR / py_filepath_rel\n module_path = f\"{PACKAGE_DIR.name}.{py_filepath_rel.with_suffix('').as_posix().replace('/', '.')}\"\n md_rel_filepath = create_markdown(md_filepath, module_path, classes, functions)\n nav_items.append(str(md_rel_filepath))\n\n # Update mkdocs.yaml with generated YAML\n update_mkdocs_file(create_nav_menu_yaml(nav_items))", "chunk_type": "function", "name": "main", "file_path": "ultralytics\\docs\\build_reference.py", "start_line": 184, "end_line": 198, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Extract class/function names, create Markdown files, and update mkdocs.yaml.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "re", "subprocess", "collections.defaultdict", "pathlib.Path" ], "chunk_id": "function_main_246f987e" }, { "content": "data = {\n \"YOLO12\": {\n \"author\": \"Yunjie Tian, Qixiang Ye, and David Doermann\",\n \"org\": \"University at Buffalo and University of Chinese Academy of Sciences\",\n \"date\": \"2025-02-18\",\n \"arxiv\": \"https://arxiv.org/abs/2502.12524\",\n \"github\": \"https://github.com/sunsmarterjie/yolov12\",\n \"docs\": \"https://docs.ultralytics.com/models/yolo12/\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 40.6, \"cpu\": \"\", \"t4\": 1.64, \"params\": 2.6, \"flops\": 6.5},\n \"s\": {\"size\": 640, \"map\": 48.0, \"cpu\": \"\", \"t4\": 2.61, \"params\": 9.3, \"flops\": 21.4},\n \"m\": {\"size\": 640, \"map\": 52.5, \"cpu\": \"\", \"t4\": 4.86, \"params\": 20.2, \"flops\": 67.5},\n \"l\": {\"size\": 640, \"map\": 53.7, \"cpu\": \"\", \"t4\": 6.77, \"params\": 26.4, \"flops\": 88.9},\n \"x\": {\"size\": 640, \"map\": 55.2, \"cpu\": \"\", \"t4\": 
11.79, \"params\": 59.1, \"flops\": 199.0},\n },\n },\n \"YOLO11\": {\n \"author\": \"Glenn Jocher and Jing Qiu\",\n \"org\": \"Ultralytics\",\n \"date\": \"2024-09-27\",\n \"arxiv\": None,\n \"github\": \"https://github.com/ultralytics/ultralytics\",\n \"docs\": \"https://docs.ultralytics.com/models/yolo11/\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 39.5, \"cpu\": 56.1, \"t4\": 1.5, \"params\": 2.6, \"flops\": 6.5},\n \"s\": {\"size\": 640, \"map\": 47.0, \"cpu\": 90.0, \"t4\": 2.5, \"params\": 9.4, \"flops\": 21.5},\n \"m\": {\"size\": 640, \"map\": 51.5, \"cpu\": 183.2, \"t4\": 4.7, \"params\": 20.1, \"flops\": 68.0},\n \"l\": {\"size\": 640, \"map\": 53.4, \"cpu\": 238.6, \"t4\": 6.2, \"params\": 25.3, \"flops\": 86.9},\n \"x\": {\"size\": 640, \"map\": 54.7, \"cpu\": 462.8, \"t4\": 11.3, \"params\": 56.9, \"flops\": 194.9},\n },\n },\n \"YOLOv10\": {\n \"author\": \"Ao Wang, Hui Chen, Lihao Liu, et al.\",\n \"org\": \"Tsinghua University\",\n \"date\": \"2024-05-23\",\n \"arxiv\": \"https://arxiv.org/abs/2405.14458\",\n \"github\": \"https://github.com/THU-MIG/yolov10\",\n \"docs\": \"https://docs.ultralytics.com/models/yolov10/\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 39.5, \"cpu\": \"\", \"t4\": 1.56, \"params\": 2.3, \"flops\": 6.7},\n \"s\": {\"size\": 640, \"map\": 46.7, \"cpu\": \"\", \"t4\": 2.66, \"params\": 7.2, \"flops\": 21.6},\n \"m\": {\"size\": 640, \"map\": 51.3, \"cpu\": \"\", \"t4\": 5.48, \"params\": 15.4, \"flops\": 59.1},\n \"b\": {\"size\": 640, \"map\": 52.7, \"cpu\": \"\", \"t4\": 6.54, \"params\": 24.4, \"flops\": 92.0},\n \"l\": {\"size\": 640, \"map\": 53.3, \"cpu\": \"\", \"t4\": 8.33, \"params\": 29.5, \"flops\": 120.3},\n \"x\": {\"size\": 640, \"map\": 54.4, \"cpu\": \"\", \"t4\": 12.2, \"params\": 56.9, \"flops\": 160.4},\n },\n },\n \"YOLOv9\": {\n \"author\": \"Chien-Yao Wang and Hong-Yuan Mark Liao\",\n \"org\": \"Institute of Information Science, Academia Sinica, Taiwan\",\n \"date\": \"2024-02-21\",\n \"arxiv\": \"https://arxiv.org/abs/2402.13616\",\n \"github\": \"https://github.com/WongKinYiu/yolov9\",\n \"docs\": \"https://docs.ultralytics.com/models/yolov9/\",\n \"performance\": {\n \"t\": {\"size\": 640, \"map\": 38.3, \"cpu\": \"\", \"t4\": 2.3, \"params\": 2.0, \"flops\": 7.7},\n \"s\": {\"size\": 640, \"map\": 46.8, \"cpu\": \"\", \"t4\": 3.54, \"params\": 7.1, \"flops\": 26.4},\n \"m\": {\"size\": 640, \"map\": 51.4, \"cpu\": \"\", \"t4\": 6.43, \"params\": 20.0, \"flops\": 76.3},\n \"c\": {\"size\": 640, \"map\": 53.0, \"cpu\": \"\", \"t4\": 7.16, \"params\": 25.3, \"flops\": 102.1},\n \"e\": {\"size\": 640, \"map\": 55.6, \"cpu\": \"\", \"t4\": 16.77, \"params\": 57.3, \"flops\": 189.0},\n },\n },\n \"YOLOv8\": {\n \"author\": \"Glenn Jocher, Ayush Chaurasia, and Jing Qiu\",\n \"org\": \"Ultralytics\",\n \"date\": \"2023-01-10\",\n \"arxiv\": None,\n \"github\": \"https://github.com/ultralytics/ultralytics\",\n \"docs\": \"https://docs.ultralytics.com/models/yolov8/\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 37.3, \"cpu\": 80.4, \"t4\": 1.47, \"params\": 3.2, \"flops\": 8.7},\n \"s\": {\"size\": 640, \"map\": 44.9, \"cpu\": 128.4, \"t4\": 2.66, \"params\": 11.2, \"flops\": 28.6},\n \"m\": {\"size\": 640, \"map\": 50.2, \"cpu\": 234.7, \"t4\": 5.86, \"params\": 25.9, \"flops\": 78.9},\n \"l\": {\"size\": 640, \"map\": 52.9, \"cpu\": 375.2, \"t4\": 9.06, \"params\": 43.7, \"flops\": 165.2},\n \"x\": {\"size\": 640, \"map\": 53.9, \"cpu\": 479.1, \"t4\": 14.37, \"params\": 68.2, \"flops\": 
257.8},\n },\n },\n \"YOLOv7\": {\n \"author\": \"Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao\",\n \"org\": \"Institute of Information Science, Academia Sinica, Taiwan\",\n \"date\": \"2022-07-06\",\n \"arxiv\": \"https://arxiv.org/abs/2207.02696\",\n \"github\": \"https://github.com/WongKinYiu/yolov7\",\n \"docs\": \"https://docs.ultralytics.com/models/yolov7/\",\n \"performance\": {\n \"l\": {\"size\": 640, \"map\": 51.4, \"cpu\": \"\", \"t4\": 6.84, \"params\": 36.9, \"flops\": 104.7},\n \"x\": {\"size\": 640, \"map\": 53.1, \"cpu\": \"\", \"t4\": 11.57, \"params\": 71.3, \"flops\": 189.9},\n },\n },\n \"YOLOv6-3.0\": {\n \"author\": \"Chuyi Li, Lulu Li, Yifei Geng, Hongliang Jiang, Meng Cheng, Bo Zhang, Zaidan Ke, Xiaoming Xu, and Xiangxiang Chu\",\n \"org\": \"Meituan\",\n \"date\": \"2023-01-13\",\n \"arxiv\": \"https://arxiv.org/abs/2301.05586\",\n \"github\": \"https://github.com/meituan/YOLOv6\",\n \"docs\": \"https://docs.ultralytics.com/models/yolov6/\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 37.5, \"cpu\": \"\", \"t4\": 1.17, \"params\": 4.7, \"flops\": 11.4},\n \"s\": {\"size\": 640, \"map\": 45.0, \"cpu\": \"\", \"t4\": 2.66, \"params\": 18.5, \"flops\": 45.3},\n \"m\": {\"size\": 640, \"map\": 50.0, \"cpu\": \"\", \"t4\": 5.28, \"params\": 34.9, \"flops\": 85.8},\n \"l\": {\"size\": 640, \"map\": 52.8, \"cpu\": \"\", \"t4\": 8.95, \"params\": 59.6, \"flops\": 150.7},\n },\n },\n \"YOLOv5\": {\n \"author\": \"Glenn Jocher\",\n \"org\": \"Ultralytics\",\n \"date\": \"2020-06-26\",\n \"arxiv\": None,\n \"github\": \"https://github.com/ultralytics/yolov5\",\n \"docs\": \"https://docs.ultralytics.com/models/yolov5/\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 28.0, \"cpu\": 73.6, \"t4\": 1.12, \"params\": 2.6, \"flops\": 7.7},\n \"s\": {\"size\": 640, \"map\": 37.4, \"cpu\": 120.7, \"t4\": 1.92, \"params\": 9.1, \"flops\": 24.0},\n \"m\": {\"size\": 640, \"map\": 45.4, \"cpu\": 233.9, \"t4\": 4.03, \"params\": 25.1, \"flops\": 64.2},\n \"l\": {\"size\": 640, \"map\": 49.0, \"cpu\": 408.4, \"t4\": 6.61, \"params\": 53.2, \"flops\": 135.0},\n \"x\": {\"size\": 640, \"map\": 50.7, \"cpu\": 763.2, \"t4\": 11.89, \"params\": 97.2, \"flops\": 246.4},\n },\n },\n \"PP-YOLOE+\": {\n \"author\": \"PaddlePaddle Authors\",\n \"org\": \"Baidu\",\n \"date\": \"2022-04-02\",\n \"arxiv\": \"https://arxiv.org/abs/2203.16250\",\n \"github\": \"https://github.com/PaddlePaddle/PaddleDetection/\",\n \"docs\": \"https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.8.1/configs/ppyoloe/README.md\",\n \"performance\": {\n \"t\": {\"size\": 640, \"map\": 39.9, \"cpu\": \"\", \"t4\": 2.84, \"params\": 4.85, \"flops\": 19.15},\n \"s\": {\"size\": 640, \"map\": 43.7, \"cpu\": \"\", \"t4\": 2.62, \"params\": 7.93, \"flops\": 17.36},\n \"m\": {\"size\": 640, \"map\": 49.8, \"cpu\": \"\", \"t4\": 5.56, \"params\": 23.43, \"flops\": 49.91},\n \"l\": {\"size\": 640, \"map\": 52.9, \"cpu\": \"\", \"t4\": 8.36, \"params\": 52.20, \"flops\": 110.07},\n \"x\": {\"size\": 640, \"map\": 54.7, \"cpu\": \"\", \"t4\": 14.3, \"params\": 98.42, \"flops\": 206.59},\n },\n },\n \"DAMO-YOLO\": {\n \"author\": \"Xianzhe Xu, Yiqi Jiang, Weihua Chen, Yilun Huang, Yuan Zhang, and Xiuyu Sun\",\n \"org\": \"Alibaba Group\",\n \"date\": \"2022-11-23\",\n \"arxiv\": \"https://arxiv.org/abs/2211.15444v2\",\n \"github\": \"https://github.com/tinyvision/DAMO-YOLO\",\n \"docs\": \"https://github.com/tinyvision/DAMO-YOLO/blob/master/README.md\",\n \"performance\": {\n \"t\": 
{\"size\": 640, \"map\": 42.0, \"cpu\": \"\", \"t4\": 2.32, \"params\": 8.5, \"flops\": 18.1},\n \"s\": {\"size\": 640, \"map\": 46.0, \"cpu\": \"\", \"t4\": 3.45, \"params\": 16.3, \"flops\": 37.8},\n \"m\": {\"size\": 640, \"map\": 49.2, \"cpu\": \"\", \"t4\": 5.09, \"params\": 28.2, \"flops\": 61.8},\n \"l\": {\"size\": 640, \"map\": 50.8, \"cpu\": \"\", \"t4\": 7.18, \"params\": 42.1, \"flops\": 97.3},\n },\n },\n \"YOLOX\": {\n \"author\": \"Zheng Ge, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun\",\n \"org\": \"Megvii\",\n \"date\": \"2021-07-18\",\n \"arxiv\": \"https://arxiv.org/abs/2107.08430\",\n \"github\": \"https://github.com/Megvii-BaseDetection/YOLOX\",\n \"docs\": \"https://yolox.readthedocs.io/en/latest/\",\n \"performance\": {\n \"nano\": {\"size\": 416, \"map\": 25.8, \"cpu\": \"\", \"t4\": \"\", \"params\": 0.91, \"flops\": 1.08},\n \"tiny\": {\"size\": 416, \"map\": 32.8, \"cpu\": \"\", \"t4\": \"\", \"params\": 5.06, \"flops\": 6.45},\n \"s\": {\"size\": 640, \"map\": 40.5, \"cpu\": \"\", \"t4\": 2.56, \"params\": 9.0, \"flops\": 26.8},\n \"m\": {\"size\": 640, \"map\": 46.9, \"cpu\": \"\", \"t4\": 5.43, \"params\": 25.3, \"flops\": 73.8},\n \"l\": {\"size\": 640, \"map\": 49.7, \"cpu\": \"\", \"t4\": 9.04, \"params\": 54.2, \"flops\": 155.6},\n \"x\": {\"size\": 640, \"map\": 51.1, \"cpu\": \"\", \"t4\": 16.1, \"params\": 99.1, \"flops\": 281.9},\n },\n },\n \"RTDETRv2\": {\n \"author\": \"Wenyu Lv, Yian Zhao, Qinyao Chang, Kui Huang, Guanzhong Wang, and Yi Liu\",\n \"org\": \"Baidu\",\n \"date\": \"2023-04-17\",\n \"arxiv\": \"https://arxiv.org/abs/2304.08069\",\n \"github\": \"https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch\",\n \"docs\": \"https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetrv2_pytorch#readme\",\n \"performance\": {\n \"s\": {\"size\": 640, \"map\": 48.1, \"cpu\": \"\", \"t4\": 5.03, \"params\": 20, \"flops\": 60},\n \"m\": {\"size\": 640, \"map\": 51.9, \"cpu\": \"\", \"t4\": 7.51, \"params\": 36, \"flops\": 100},\n \"l\": {\"size\": 640, \"map\": 53.4, \"cpu\": \"\", \"t4\": 9.76, \"params\": 42, \"flops\": 136},\n \"x\": {\"size\": 640, \"map\": 54.3, \"cpu\": \"\", \"t4\": 15.03, \"params\": 76, \"flops\": 259},\n },\n },\n \"EfficientDet\": {\n \"author\": \"Mingxing Tan, Ruoming Pang, and Quoc V. 
Le\",\n \"org\": \"Google\",\n \"date\": \"2019-11-20\",\n \"arxiv\": \"https://arxiv.org/abs/1911.09070\",\n \"github\": \"https://github.com/google/automl/tree/master/efficientdet\",\n \"docs\": \"https://github.com/google/automl/tree/master/efficientdet#readme\",\n \"performance\": {\n \"d0\": {\"size\": 640, \"map\": 34.6, \"cpu\": 10.2, \"t4\": 3.92, \"params\": 3.9, \"flops\": 2.54},\n \"d1\": {\"size\": 640, \"map\": 40.5, \"cpu\": 13.5, \"t4\": 7.31, \"params\": 6.6, \"flops\": 6.10},\n \"d2\": {\"size\": 640, \"map\": 43.0, \"cpu\": 17.7, \"t4\": 10.92, \"params\": 8.1, \"flops\": 11.0},\n \"d3\": {\"size\": 640, \"map\": 47.5, \"cpu\": 28.0, \"t4\": 19.59, \"params\": 12.0, \"flops\": 24.9},\n \"d4\": {\"size\": 640, \"map\": 49.7, \"cpu\": 42.8, \"t4\": 33.55, \"params\": 20.7, \"flops\": 55.2},\n \"d5\": {\"size\": 640, \"map\": 51.5, \"cpu\": 72.5, \"t4\": 67.86, \"params\": 33.7, \"flops\": 130.0},\n \"d6\": {\"size\": 640, \"map\": 52.6, \"cpu\": 92.8, \"t4\": 89.29, \"params\": 51.9, \"flops\": 226.0},\n \"d7\": {\"size\": 640, \"map\": 53.7, \"cpu\": 122.0, \"t4\": 128.07, \"params\": 51.9, \"flops\": 325.0},\n },\n },\n \"Gold-YOLO\": {\n \"author\": \"Cheng Wang, Wei He, Ying Nie, Jianyuan Guo, Chuanjian Liu, Yunhe Wang, and Kai Han\",\n \"org\": \"Huawei Noah's Ark Lab\",\n \"date\": \"2023-09-20\",\n \"arxiv\": \"https://arxiv.org/abs/2309.11331\",\n \"github\": \"https://github.com/huawei-noah/Efficient-Computing/tree/master/Detection/Gold-YOLO\",\n \"docs\": \"https://github.com/huawei-noah/Efficient-Computing/blob/master/Detection/Gold-YOLO/README.md\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 39.9, \"cpu\": \"\", \"t4\": 1.66, \"params\": 5.6, \"flops\": 12.1},\n \"s\": {\"size\": 640, \"map\": 46.4, \"cpu\": \"\", \"t4\": 3.43, \"params\": 21.5, \"flops\": 46.0},\n \"m\": {\"size\": 640, \"map\": 51.1, \"cpu\": \"\", \"t4\": 6.43, \"params\": 41.3, \"flops\": 87.5},\n \"l\": {\"size\": 640, \"map\": 53.3, \"cpu\": \"\", \"t4\": 10.64, \"params\": 75.1, \"flops\": 151.7},\n },\n },\n \"D-FINE\": {\n \"author\": \"Yansong Peng, Hebei Li, Peixi Wu, Yueyi Zhang, Xiaoyan Sun, and Feng Wu\",\n \"org\": \"University of Science and Technology of China\",\n \"date\": \"2024-10-17\",\n \"arxiv\": \"https://arxiv.org/abs/2410.13842\",\n \"github\": \"https://github.com/Peterande/D-FINE\",\n \"docs\": \"https://github.com/Peterande/D-FINE/blob/master/README.md\",\n \"performance\": {\n \"n\": {\"size\": 640, \"map\": 42.8, \"cpu\": \"\", \"t4\": 2.28, \"params\": 4, \"flops\": 7},\n \"s\": {\"size\": 640, \"map\": 48.5, \"cpu\": \"\", \"t4\": 4.19, \"params\": 10, \"flops\": 25},\n \"m\": {\"size\": 640, \"map\": 52.3, \"cpu\": \"\", \"t4\": 6.85, \"params\": 19, \"flops\": 57},\n \"l\": {\"size\": 640, \"map\": 54.0, \"cpu\": \"\", \"t4\": 9.50, \"params\": 31, \"flops\": 91},\n \"x\": {\"size\": 640, \"map\": 55.8, \"cpu\": \"\", \"t4\": 15.04, \"params\": 62, \"flops\": 202},\n },\n },\n \"YOLO-World\": {\n \"author\": \"Tianheng Cheng, Lin Song, Yixiao Ge, Wenyu Liu, Xinggang Wang, and Ying Shan\",\n \"org\": \"Tencent AILab Computer Vision Center\",\n \"date\": \"2024-01-30\",\n \"arxiv\": \"https://arxiv.org/abs/2401.17270\",\n \"github\": \"https://github.com/AILab-CVC/YOLO-World\",\n \"docs\": \"https://docs.ultralytics.com/models/yolo-world/\",\n \"performance\": {\n \"s\": {\"size\": 640, \"map\": 46.1, \"cpu\": \"\", \"t4\": 3.46, \"params\": 12.7, \"flops\": 51.0},\n \"m\": {\"size\": 640, \"map\": 51.0, \"cpu\": \"\", \"t4\": 7.26, \"params\": 28.4, 
\"flops\": 110.5},\n \"l\": {\"size\": 640, \"map\": 53.9, \"cpu\": \"\", \"t4\": 11.00, \"params\": 46.8, \"flops\": 204.5},\n \"x\": {\"size\": 640, \"map\": 54.7, \"cpu\": \"\", \"t4\": 17.24, \"params\": 72.88, \"flops\": 309.3},\n },\n },\n \"RTMDet\": {\n \"author\": \"Chengqi Lyu, Wenwei Zhang, Haian Huang, Yue Zhou, Yudong Wang, Yanyi Liu, Shilong Zhang, and Kai Chen\",\n \"org\": \"OpenMMLab\",\n \"date\": \"2022-12-14\",\n \"arxiv\": \"https://arxiv.org/abs/2212.07784\",\n \"github\": \"https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet\",\n \"docs\": \"https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet#readme\",\n \"performance\": {\n \"t\": {\"size\": 640, \"map\": 41.1, \"cpu\": \"\", \"t4\": 2.54, \"params\": 4.8, \"flops\": 8.1},\n \"s\": {\"size\": 640, \"map\": 44.6, \"cpu\": \"\", \"t4\": 3.18, \"params\": 8.89, \"flops\": 14.8},\n \"m\": {\"size\": 640, \"map\": 49.4, \"cpu\": \"\", \"t4\": 6.82, \"params\": 24.71, \"flops\": 39.27},\n \"l\": {\"size\": 640, \"map\": 51.5, \"cpu\": \"\", \"t4\": 11.06, \"params\": 52.3, \"flops\": 80.23},\n \"x\": {\"size\": 640, \"map\": 52.8, \"cpu\": \"\", \"t4\": 19.66, \"params\": 94.86, \"flops\": 141.67},\n },\n },\n \"YOLO-NAS\": {\n \"author\": \"Shay Aharon, Louis-Dupont, Ofri Masad, Kate Yurkova, Lotem Fridman, Lkdci, Eugene Khvedchenya, Ran Rubin, Natan Bagrov, Borys Tymchenko, Tomer Keren, Alexander Zhilko, and Eran-Deci\",\n \"org\": \"Deci AI (acquired by NVIDIA)\",\n \"date\": \"2023-05-03\",\n \"arxiv\": None,\n \"github\": \"https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md\",\n \"docs\": \"https://docs.ultralytics.com/models/yolo-nas/\",\n \"performance\": {\n \"s\": {\"size\": 640, \"map\": 47.5, \"cpu\": \"\", \"t4\": 3.09, \"params\": 12.2, \"flops\": 32.8},\n \"m\": {\"size\": 640, \"map\": 51.6, \"cpu\": \"\", \"t4\": 6.07, \"params\": 31.9, \"flops\": 88.9},\n \"l\": {\"size\": 640, \"map\": 52.2, \"cpu\": \"\", \"t4\": 7.84, \"params\": 42.02, \"flops\": 121.09},\n },\n },\n \"FCOS\": {\n \"author\": \"Zhi Tian, Chunhua Shen, Hao Chen, and Tong He\",\n \"org\": \"The University of Adelaide\",\n \"date\": \"2019-04-02\",\n \"arxiv\": \"https://arxiv.org/abs/1904.01355\",\n \"github\": \"https://github.com/tianzhi0549/FCOS/\",\n \"docs\": \"https://github.com/tianzhi0549/FCOS/?tab=readme-ov-file#installation\",\n \"performance\": {\n \"R50\": {\"size\": 800, \"map\": 36.6, \"cpu\": \"\", \"t4\": 15.18, \"params\": 32.3, \"flops\": 250.9},\n \"R101\": {\"size\": 800, \"map\": 39.1, \"cpu\": \"\", \"t4\": 18.91, \"params\": 51.28, \"flops\": 346.1},\n },\n },\n \"SSD\": {\n \"author\": \"Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. 
Berg\",\n \"org\": \"University of North Carolina at Chapel Hill\",\n \"date\": \"2015-12-08\",\n \"arxiv\": \"https://arxiv.org/abs/1512.02325\",\n \"github\": \"https://github.com/weiliu89/caffe/tree/ssd\",\n \"docs\": \"https://github.com/weiliu89/caffe/tree/ssd?tab=readme-ov-file#installation\",\n \"performance\": {\n \"300\": {\"size\": 300, \"map\": 25.5, \"cpu\": \"\", \"t4\": 3.97, \"params\": 34.3, \"flops\": 68.7},\n \"512\": {\"size\": 512, \"map\": 29.5, \"cpu\": \"\", \"t4\": 8.96, \"params\": 36.0, \"flops\": 197.3},\n },\n },\n \"RTDETRv3\": {\n \"author\": \"Shuo Wang, Chunlong Xia, Feng Lv and Yifeng Shi\",\n \"org\": \"Baidu\",\n \"date\": \"2024-09-13\",\n \"arxiv\": \"https://arxiv.org/abs/2409.08475\",\n \"github\": \"https://github.com/clxia12/RT-DETRv3\",\n \"docs\": \"https://github.com/clxia12/RT-DETRv3/blob/main/README.md\",\n \"performance\": {\n \"s\": {\"size\": 640, \"map\": 48.1, \"cpu\": \"\", \"t4\": 5.03, \"params\": 20, \"flops\": 60},\n \"m\": {\"size\": 640, \"map\": 49.9, \"cpu\": \"\", \"t4\": 7.51, \"params\": 36, \"flops\": 100},\n \"l\": {\"size\": 640, \"map\": 53.4, \"cpu\": \"\", \"t4\": 9.76, \"params\": 42, \"flops\": 136},\n \"x\": {\"size\": 640, \"map\": 54.6, \"cpu\": \"\", \"t4\": 15.03, \"params\": 76, \"flops\": 259},\n },\n },\n \"LWDETR\": {\n \"author\": \"Qiang Chen, Xiangbo Su, Xinyu Zhang, Jian Wang, Jiahui Chen, Yunpeng Shen, Chuchu Han, Ziliang Chen, Weixiang Xu, Fanrong Li, Shan Zhang, Kun Yao, Errui Ding, Gang Zhang, and Jingdong Wang\",\n \"org\": \"Baidu\",\n \"date\": \"2024-06-05\",\n \"arxiv\": \"https://arxiv.org/abs/2406.03459\",\n \"github\": \"https://github.com/Atten4Vis/LW-DETR\",\n \"docs\": \"https://github.com/Atten4Vis/LW-DETR/blob/main/README.md\",\n \"performance\": {\n \"t\": {\"size\": 640, \"map\": 42.6, \"cpu\": \"\", \"t4\": 2.56, \"params\": 12.1, \"flops\": 11.2},\n \"s\": {\"size\": 640, \"map\": 48.0, \"cpu\": \"\", \"t4\": 3.72, \"params\": 14.6, \"flops\": 16.6},\n \"m\": {\"size\": 640, \"map\": 52.5, \"cpu\": \"\", \"t4\": 6.59, \"params\": 28.2, \"flops\": 42.8},\n \"l\": {\"size\": 640, \"map\": 56.1, \"cpu\": \"\", \"t4\": 10.57, \"params\": 46.8, \"flops\": 71.6},\n \"x\": {\"size\": 640, \"map\": 58.3, \"cpu\": \"\", \"t4\": 22.29, \"params\": 118.0, \"flops\": 174.1},\n },\n },\n}", "chunk_type": "variable", "name": "data", "file_path": "ultralytics\\docs\\model_data.py", "start_line": 14, "end_line": 333, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_data_9b2c3999" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\tests\\conftest.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_a61ff766" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\tests\\conftest.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_f9c888c2" }, { "content": "from tests import TMP", "chunk_type": "import", "name": "TMP", "file_path": "ultralytics\\tests\\conftest.py", 
"start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TMP_1adbbd4c" }, { "content": "def pytest_addoption(parser):\n \"\"\"Add custom command-line options to pytest.\"\"\"\n parser.addoption(\"--slow\", action=\"store_true\", default=False, help=\"Run slow tests\")", "chunk_type": "function", "name": "pytest_addoption", "file_path": "ultralytics\\tests\\conftest.py", "start_line": 9, "end_line": 11, "start_col": 0, "end_col": 89, "parent_name": null, "docstring": "Add custom command-line options to pytest.", "parameters": [ "parser" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "shutil", "pathlib.Path", "tests.TMP", "ultralytics.utils.torch_utils.init_seeds", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_pytest_addoption_6f6b34fb" }, { "content": "def pytest_collection_modifyitems(config, items):\n \"\"\"\n Modify the list of test items to exclude tests marked as slow if the --slow option is not specified.\n\n Args:\n config: The pytest configuration object that provides access to command-line options.\n items (list): The list of collected pytest item objects to be modified based on the presence of --slow option.\n \"\"\"\n if not config.getoption(\"--slow\"):\n # Remove the item entirely from the list of test items if it's marked as 'slow'\n items[:] = [item for item in items if \"slow\" not in item.keywords]", "chunk_type": "function", "name": "pytest_collection_modifyitems", "file_path": "ultralytics\\tests\\conftest.py", "start_line": 14, "end_line": 24, "start_col": 0, "end_col": 74, "parent_name": null, "docstring": "Modify the list of test items to exclude tests marked as slow if the --slow option is not specified.\n\nArgs:\n config: The pytest configuration object that provides access to command-line options.\n items (list): The list of collected pytest item objects to be modified based on the presence of --slow option.", "parameters": [ "config", "items" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "shutil", "pathlib.Path", "tests.TMP", "ultralytics.utils.torch_utils.init_seeds", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_pytest_collection_modifyitems_58986c96" }, { "content": "def pytest_sessionstart(session):\n \"\"\"\n Initialize session configurations for pytest.\n\n This function is automatically called by pytest after the 'Session' object has been created but before performing\n test collection. It sets the initial seeds and prepares the temporary directory for the test session.\n\n Args:\n session: The pytest session object.\n \"\"\"\n from ultralytics.utils.torch_utils import init_seeds\n\n init_seeds()\n shutil.rmtree(TMP, ignore_errors=True) # Delete any existing tests/tmp directory\n TMP.mkdir(parents=True, exist_ok=True) # Create a new empty directory", "chunk_type": "function", "name": "pytest_sessionstart", "file_path": "ultralytics\\tests\\conftest.py", "start_line": 27, "end_line": 41, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Initialize session configurations for pytest.\n\nThis function is automatically called by pytest after the 'Session' object has been created but before performing\ntest collection. 
It sets the initial seeds and prepares the temporary directory for the test session.\n\nArgs:\n session: The pytest session object.", "parameters": [ "session" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "shutil", "pathlib.Path", "tests.TMP", "ultralytics.utils.torch_utils.init_seeds", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_pytest_sessionstart_15fd39ce" }, { "content": "def pytest_terminal_summary(terminalreporter, exitstatus, config):\n \"\"\"\n Cleanup operations after pytest session.\n\n This function is automatically called by pytest at the end of the entire test session. It removes certain files\n and directories used during testing.\n\n Args:\n terminalreporter: The terminal reporter object used for terminal output.\n exitstatus (int): The exit status of the test run.\n config: The pytest config object.\n \"\"\"\n from ultralytics.utils import WEIGHTS_DIR\n\n # Remove files\n models = [path for x in {\"*.onnx\", \"*.torchscript\"} for path in WEIGHTS_DIR.rglob(x)]\n for file in [\"decelera_portrait_min.mov\", \"bus.jpg\", \"yolo11n.onnx\", \"yolo11n.torchscript\"] + models:\n Path(file).unlink(missing_ok=True)\n\n # Remove directories\n models = [path for x in {\"*.mlpackage\", \"*_openvino_model\"} for path in WEIGHTS_DIR.rglob(x)]\n for directory in [WEIGHTS_DIR / \"path with spaces\", TMP.parents[1] / \".pytest_cache\", TMP] + models:\n shutil.rmtree(directory, ignore_errors=True)", "chunk_type": "function", "name": "pytest_terminal_summary", "file_path": "ultralytics\\tests\\conftest.py", "start_line": 44, "end_line": 66, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": "Cleanup operations after pytest session.\n\nThis function is automatically called by pytest at the end of the entire test session. 
It removes certain files\nand directories used during testing.\n\nArgs:\n terminalreporter: The terminal reporter object used for terminal output.\n exitstatus (int): The exit status of the test run.\n config: The pytest config object.", "parameters": [ "terminalreporter", "exitstatus", "config" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "shutil", "pathlib.Path", "tests.TMP", "ultralytics.utils.torch_utils.init_seeds", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_pytest_terminal_summary_ae822bd5" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_ce47048d" }, { "content": "import pytest", "chunk_type": "import", "name": "pytest", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pytest_7e711d04" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_c0a6e162" }, { "content": "from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA", "chunk_type": "import", "name": "CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA_40d6bb41" }, { "content": "from ultralytics.utils import ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks", "chunk_type": "import", "name": "ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks_f6a45a09" }, { "content": "from ultralytics.utils.torch_utils import TORCH_1_9", "chunk_type": "import", "name": "TORCH_1_9", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_9_d739b85b" }, { "content": "def run(cmd: str) -> None:\n \"\"\"Execute a shell command using subprocess.\"\"\"\n subprocess.run(cmd.split(), check=True)", "chunk_type": "function", "name": "run", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 13, "end_line": 15, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Execute a shell command using subprocess.", "parameters": [ "cmd: 
str" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_run_d02283bd" }, { "content": "def test_special_modes() -> None:\n \"\"\"Test various special command-line modes for YOLO functionality.\"\"\"\n run(\"yolo help\")\n run(\"yolo checks\")\n run(\"yolo version\")\n run(\"yolo settings reset\")\n run(\"yolo cfg\")", "chunk_type": "function", "name": "test_special_modes", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 18, "end_line": 24, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Test various special command-line modes for YOLO functionality.", "parameters": [], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_special_modes_ad3732d6" }, { "content": "def test_train(task: str, model: str, data: str) -> None:\n \"\"\"Test YOLO training for different tasks, models, and datasets.\"\"\"\n run(f\"yolo train {task} model={model} data={data} imgsz=32 epochs=1 cache=disk\")", "chunk_type": "function", "name": "test_train", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 28, "end_line": 30, "start_col": 0, "end_col": 84, "parent_name": null, "docstring": "Test YOLO training for different tasks, models, and datasets.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('task,model,data', TASK_MODEL_DATA)" ], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_train_4a0ae3e7" }, { "content": "def test_val(task: str, model: str, data: str) -> None:\n \"\"\"Test YOLO validation process for specified task, model, and data using a shell command.\"\"\"\n run(f\"yolo val {task} model={model} data={data} imgsz=32 save_txt save_json\")", "chunk_type": "function", "name": "test_val", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 34, "end_line": 36, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "Test YOLO validation process for specified task, model, and data using a shell command.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('task,model,data', TASK_MODEL_DATA)" ], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", 
"tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_val_60b053ef" }, { "content": "def test_predict(task: str, model: str, data: str) -> None:\n \"\"\"Test YOLO prediction on provided sample assets for specified task and model.\"\"\"\n run(f\"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt\")", "chunk_type": "function", "name": "test_predict", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 40, "end_line": 42, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": "Test YOLO prediction on provided sample assets for specified task and model.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('task,model,data', TASK_MODEL_DATA)" ], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_predict_cda3887d" }, { "content": "def test_export(model: str) -> None:\n \"\"\"Test exporting a YOLO model to TorchScript format.\"\"\"\n run(f\"yolo export model={model} format=torchscript imgsz=32\")", "chunk_type": "function", "name": "test_export", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 46, "end_line": 48, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": "Test exporting a YOLO model to TorchScript format.", "parameters": [ "model: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('model', MODELS)" ], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_export_82484e5d" }, { "content": "def test_rtdetr(task: str = \"detect\", model: str = \"yolov8n-rtdetr.yaml\", data: str = \"coco8.yaml\") -> None:\n \"\"\"Test the RTDETR functionality within Ultralytics for detection tasks using specified model and data.\"\"\"\n # Warning: must use imgsz=640 (note also add comma, spaces, fraction=0.25 args to test single-image training)\n run(f\"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk fraction=0.25\") # spaces\n run(f\"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt\")\n if TORCH_1_9:\n weights = WEIGHTS_DIR / \"rtdetr-l.pt\"\n run(f\"yolo predict {task} model={weights} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt\")\n run(f\"yolo train {task} model={weights} epochs=1 imgsz=160 cache=disk data=coco8.yaml\")", "chunk_type": "function", "name": "test_rtdetr", "file_path": "ultralytics\\tests\\test_cli.py", 
"start_line": 51, "end_line": 59, "start_col": 0, "end_col": 95, "parent_name": null, "docstring": "Test the RTDETR functionality within Ultralytics for detection tasks using specified model and data.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_rtdetr_eb018143" }, { "content": "def test_fastsam(\n task: str = \"segment\", model: str = WEIGHTS_DIR / \"FastSAM-s.pt\", data: str = \"coco8-seg.yaml\"\n) -> None:\n \"\"\"Test FastSAM model for segmenting objects in images using various prompts within Ultralytics.\"\"\"\n source = ASSETS / \"bus.jpg\"\n\n run(f\"yolo segment val {task} model={model} data={data} imgsz=32\")\n run(f\"yolo segment predict model={model} source={source} imgsz=32 save save_crop save_txt\")\n\n from ultralytics import FastSAM\n from ultralytics.models.sam import Predictor\n\n # Create a FastSAM model\n sam_model = FastSAM(model) # or FastSAM-x.pt\n\n # Run inference on an image\n for s in (source, Image.open(source)):\n everything_results = sam_model(s, device=\"cpu\", retina_masks=True, imgsz=320, conf=0.4, iou=0.9)\n\n # Remove small regions\n new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20)\n\n # Run inference with bboxes and points and texts prompt at the same time\n sam_model(source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts=\"a photo of a dog\")", "chunk_type": "function", "name": "test_fastsam", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 67, "end_line": 90, "start_col": 0, "end_col": 113, "parent_name": null, "docstring": "Test FastSAM model for segmenting objects in images using various prompts within Ultralytics.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.skipif(checks.IS_PYTHON_3_12, reason='MobileSAM with CLIP is not supported in Python 3.12')", "pytest.mark.skipif(checks.IS_PYTHON_3_8 and LINUX and ARM64, reason='MobileSAM with CLIP is not supported in Python 3.8 and aarch64 Linux')" ], "complexity_score": 2, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_fastsam_f9e90a8a" }, { "content": "def test_mobilesam() -> None:\n \"\"\"Test MobileSAM segmentation with point prompts using Ultralytics.\"\"\"\n from ultralytics import SAM\n\n # Load the model\n model = SAM(WEIGHTS_DIR / \"mobile_sam.pt\")\n\n # Source\n source = ASSETS / \"zidane.jpg\"\n\n # Predict a segment based on a 1D point prompt and 1D labels.\n model.predict(source, points=[900, 370], labels=[1])\n\n # Predict a segment based on 3D points and 2D labels (multiple points per object).\n model.predict(source, points=[[[900, 370], [1000, 
100]]], labels=[[1, 1]])\n\n # Predict a segment based on a box prompt\n model.predict(source, bboxes=[439, 437, 524, 709], save=True)", "chunk_type": "function", "name": "test_mobilesam", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 93, "end_line": 110, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": "Test MobileSAM segmentation with point prompts using Ultralytics.", "parameters": [], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_mobilesam_f6b13e42" }, { "content": "def test_train_gpu(task: str, model: str, data: str) -> None:\n \"\"\"Test YOLO training on GPU(s) for various tasks and models.\"\"\"\n run(f\"yolo train {task} model={model} data={data} imgsz=32 epochs=1 device=0\") # single GPU\n run(f\"yolo train {task} model={model} data={data} imgsz=32 epochs=1 device=0,1\") # multi GPU", "chunk_type": "function", "name": "test_train_gpu", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 121, "end_line": 124, "start_col": 0, "end_col": 84, "parent_name": null, "docstring": "Test YOLO training on GPU(s) for various tasks and models.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.slow", "pytest.mark.parametrize('task,model,data', TASK_MODEL_DATA)", "pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')", "pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason='DDP is not available')" ], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_train_gpu_7f398aec" }, { "content": "def test_solutions(solution: str) -> None:\n \"\"\"Test yolo solutions command-line modes.\"\"\"\n run(f\"yolo solutions {solution} verbose=False\")", "chunk_type": "function", "name": "test_solutions", "file_path": "ultralytics\\tests\\test_cli.py", "start_line": 131, "end_line": 133, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": "Test yolo solutions command-line modes.", "parameters": [ "solution: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('solution', ['count', 'blur', 'workout', 'heatmap', 'isegment', 'visioneye', 'speed', 'queue', 'analytics', 'trackzone'])" ], "complexity_score": 1, "dependencies": [ "subprocess", "pytest", "PIL.Image", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODELS", "tests.TASK_MODEL_DATA", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.LINUX", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.FastSAM", "ultralytics.models.sam.Predictor", "ultralytics.SAM" ], "chunk_id": "function_test_solutions_7f4245c8" }, { "content": "from itertools import product", 
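An illustrative aside on the `device=0` vs `device=0,1` strings used by `test_train_gpu` above: in the Python API the device selection is commonly passed as an int, a comma-separated string, or a list, with a multi-GPU selection triggering DDP training. The sketch below is an assumption-labelled example (it picks whatever the current machine can satisfy), not repository code.

```python
# Sketch: choose a `device` spec the current machine can satisfy before training.
import torch
from ultralytics import YOLO

n = torch.cuda.device_count()
device = [0, 1] if n >= 2 else (0 if n == 1 else "cpu")  # list -> DDP, int -> single GPU, str -> CPU

YOLO("yolo11n.pt").train(data="coco8.yaml", imgsz=32, epochs=1, device=device)
```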
"chunk_type": "import", "name": "product", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_product_fb1704d2" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_aa6929ae" }, { "content": "import pytest", "chunk_type": "import", "name": "pytest", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pytest_2fc138f5" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_4421c21f" }, { "content": "from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODEL, SOURCE", "chunk_type": "import", "name": "CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODEL, SOURCE", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODEL, SOURCE_23dc17a9" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_0dfeacbf" }, { "content": "from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS", "chunk_type": "import", "name": "TASK2DATA, TASK2MODEL, TASKS", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, TASK2MODEL, TASKS_3f3bc108" }, { "content": "from ultralytics.utils import ASSETS, IS_JETSON, WEIGHTS_DIR", "chunk_type": "import", "name": "ASSETS, IS_JETSON, WEIGHTS_DIR", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, IS_JETSON, WEIGHTS_DIR_d61bc294" }, { "content": "from ultralytics.utils.autodevice import GPUInfo", "chunk_type": "import", "name": "GPUInfo", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 48, "parent_name": null, 
"docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_GPUInfo_ea9ae67b" }, { "content": "from ultralytics.utils.checks import check_amp", "chunk_type": "import", "name": "check_amp", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_amp_a674ed51" }, { "content": "from ultralytics.utils.torch_utils import TORCH_1_13", "chunk_type": "import", "name": "TORCH_1_13", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_13_3d38ddb2" }, { "content": "DEVICES = []", "chunk_type": "variable", "name": "DEVICES", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEVICES_d4fcbd07" }, { "content": "def test_checks():\n \"\"\"Validate CUDA settings against torch CUDA functions.\"\"\"\n assert torch.cuda.is_available() == CUDA_IS_AVAILABLE\n assert torch.cuda.device_count() == CUDA_DEVICE_COUNT", "chunk_type": "function", "name": "test_checks", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 33, "end_line": 36, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": "Validate CUDA settings against torch CUDA functions.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_checks_2d43d47e" }, { "content": "def test_amp():\n \"\"\"Test AMP training checks.\"\"\"\n model = YOLO(\"yolo11n.pt\").model.to(f\"cuda:{DEVICES[0]}\")\n assert check_amp(model)", "chunk_type": "function", "name": "test_amp", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 40, "end_line": 43, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Test AMP training checks.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')" ], "complexity_score": 1, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", 
"ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_amp_3909ae17" }, { "content": "def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):\n \"\"\"Test YOLO exports to ONNX format with various configurations and parameters.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"onnx\",\n imgsz=32,\n dynamic=dynamic,\n int8=int8,\n half=half,\n batch=batch,\n simplify=simplify,\n nms=nms,\n device=DEVICES[0],\n )\n YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32, device=DEVICES[0]) # exported model inference\n Path(file).unlink() # cleanup", "chunk_type": "function", "name": "test_export_onnx_matrix", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 60, "end_line": 74, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test YOLO exports to ONNX format with various configurations and parameters.", "parameters": [ "task", "dynamic", "int8", "half", "batch", "simplify", "nms" ], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')", "pytest.mark.parametrize('task, dynamic, int8, half, batch, simplify, nms', [(task, dynamic, int8, half, batch, simplify, nms) for task, dynamic, int8, half, batch, simplify, nms in product(TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]) if not (int8 and half or (task == 'classify' and nms) or (task == 'obb' and nms and (not TORCH_1_13 or IS_JETSON)))])" ], "complexity_score": 2, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_export_onnx_matrix_1b8e788f" }, { "content": "def test_export_engine_matrix(task, dynamic, int8, half, batch):\n \"\"\"Test YOLO model export to TensorRT format for various configurations and run inference.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"engine\",\n imgsz=32,\n dynamic=dynamic,\n int8=int8,\n half=half,\n batch=batch,\n data=TASK2DATA[task],\n workspace=1, # reduce workspace GB for less resource utilization during testing\n simplify=True,\n device=DEVICES[0],\n )\n YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32, device=DEVICES[0]) # exported model inference\n Path(file).unlink() # cleanup\n Path(file).with_suffix(\".cache\").unlink() if int8 else None # cleanup INT8 cache", "chunk_type": "function", "name": "test_export_engine_matrix", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 90, "end_line": 106, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": "Test YOLO model export to TensorRT format for various configurations and run inference.", "parameters": [ "task", "dynamic", "int8", "half", "batch" ], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(True, 
reason='CUDA export tests disabled pending additional Ultralytics GPU server availability')", "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')", "pytest.mark.parametrize('task, dynamic, int8, half, batch', [(task, dynamic, int8, half, batch) for task, dynamic, int8, half, batch in product(TASKS, [True], [True], [False], [2]) if not (int8 and half)])" ], "complexity_score": 2, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_export_engine_matrix_26d34800" }, { "content": "def test_train():\n \"\"\"Test model training on a minimal dataset using available CUDA devices.\"\"\"\n import os\n\n device = tuple(DEVICES) if len(DEVICES) > 1 else DEVICES[0]\n results = YOLO(MODEL).train(data=\"coco8.yaml\", imgsz=64, epochs=1, device=device) # requires imgsz>=64\n # NVIDIA Jetson only has one GPU and therefore skipping checks\n if not IS_JETSON:\n visible = eval(os.environ[\"CUDA_VISIBLE_DEVICES\"])\n assert visible == device, f\"Passed GPUs '{device}', but used GPUs '{visible}'\"\n assert (\n (results is None) if len(DEVICES) > 1 else (results is not None)\n ) # DDP returns None, single-GPU returns metrics", "chunk_type": "function", "name": "test_train", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 110, "end_line": 122, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Test model training on a minimal dataset using available CUDA devices.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')" ], "complexity_score": 2, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_train_14ca9189" }, { "content": "def test_predict_multiple_devices():\n \"\"\"Validate model prediction consistency across CPU and CUDA devices.\"\"\"\n model = YOLO(\"yolo11n.pt\")\n\n # Test CPU\n model = model.cpu()\n assert str(model.device) == \"cpu\"\n _ = model(SOURCE)\n assert str(model.device) == \"cpu\"\n\n # Test CUDA\n cuda_device = f\"cuda:{DEVICES[0]}\"\n model = model.to(cuda_device)\n assert str(model.device) == cuda_device\n _ = model(SOURCE)\n assert str(model.device) == cuda_device\n\n # Test CPU again\n model = model.cpu()\n assert str(model.device) == \"cpu\"\n _ = model(SOURCE)\n assert str(model.device) == \"cpu\"\n\n # Test CUDA again\n model = model.to(cuda_device)\n assert 
str(model.device) == cuda_device\n _ = model(SOURCE)\n assert str(model.device) == cuda_device", "chunk_type": "function", "name": "test_predict_multiple_devices", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 127, "end_line": 154, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Validate model prediction consistency across CPU and CUDA devices.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')" ], "complexity_score": 1, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_predict_multiple_devices_b467d436" }, { "content": "def test_autobatch():\n \"\"\"Check optimal batch size for YOLO model training using autobatch utility.\"\"\"\n from ultralytics.utils.autobatch import check_train_batch_size\n\n check_train_batch_size(YOLO(MODEL).model.to(f\"cuda:{DEVICES[0]}\"), imgsz=128, amp=True)", "chunk_type": "function", "name": "test_autobatch", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 158, "end_line": 162, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": "Check optimal batch size for YOLO model training using autobatch utility.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')" ], "complexity_score": 1, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_autobatch_b1e404dc" }, { "content": "def test_utils_benchmarks():\n \"\"\"Profile YOLO models for performance benchmarks.\"\"\"\n from ultralytics.utils.benchmarks import ProfileModels\n\n # Pre-export a dynamic engine model to use dynamic inference\n YOLO(MODEL).export(format=\"engine\", imgsz=32, dynamic=True, batch=1, device=DEVICES[0])\n ProfileModels(\n [MODEL],\n imgsz=32,\n half=False,\n min_time=1,\n num_timed_runs=3,\n num_warmup_runs=1,\n device=DEVICES[0],\n ).run()", "chunk_type": "function", "name": "test_utils_benchmarks", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 167, "end_line": 181, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": "Profile YOLO models for performance benchmarks.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')" ], "complexity_score": 
1, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_utils_benchmarks_1810bc78" }, { "content": "def test_predict_sam():\n \"\"\"Test SAM model predictions using different prompts.\"\"\"\n from ultralytics import SAM\n from ultralytics.models.sam import Predictor as SAMPredictor\n\n model = SAM(WEIGHTS_DIR / \"sam2.1_b.pt\")\n model.info()\n\n # Run inference with various prompts\n model(SOURCE, device=DEVICES[0])\n model(SOURCE, bboxes=[439, 437, 524, 709], device=DEVICES[0])\n model(ASSETS / \"zidane.jpg\", points=[900, 370], device=DEVICES[0])\n model(ASSETS / \"zidane.jpg\", points=[900, 370], labels=[1], device=DEVICES[0])\n model(ASSETS / \"zidane.jpg\", points=[[900, 370]], labels=[1], device=DEVICES[0])\n model(ASSETS / \"zidane.jpg\", points=[[400, 370], [900, 370]], labels=[1, 1], device=DEVICES[0])\n model(ASSETS / \"zidane.jpg\", points=[[[900, 370], [1000, 100]]], labels=[[1, 1]], device=DEVICES[0])\n\n # Test predictor\n predictor = SAMPredictor(\n overrides=dict(\n conf=0.25,\n task=\"segment\",\n mode=\"predict\",\n imgsz=1024,\n model=WEIGHTS_DIR / \"mobile_sam.pt\",\n device=DEVICES[0],\n )\n )\n predictor.set_image(ASSETS / \"zidane.jpg\")\n # predictor(bboxes=[439, 437, 524, 709])\n # predictor(points=[900, 370], labels=[1])\n predictor.reset_image()", "chunk_type": "function", "name": "test_predict_sam", "file_path": "ultralytics\\tests\\test_cuda.py", "start_line": 185, "end_line": 216, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Test SAM model predictions using different prompts.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not DEVICES, reason='No CUDA devices available')" ], "complexity_score": 1, "dependencies": [ "itertools.product", "pathlib.Path", "pytest", "torch", "tests.CUDA_DEVICE_COUNT", "tests.CUDA_IS_AVAILABLE", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.checks.check_amp", "ultralytics.utils.torch_utils.TORCH_1_13", "os", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.SAM", "ultralytics.models.sam.Predictor" ], "chunk_id": "function_test_predict_sam_389c87ce" }, { "content": "import sys", "chunk_type": "import", "name": "sys", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_sys_413f793e" }, { "content": "from unittest import mock", "chunk_type": "import", "name": "mock", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 4, "end_line": 4, 
"start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_mock_11d2d6c8" }, { "content": "from tests import MODEL", "chunk_type": "import", "name": "MODEL", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MODEL_3194f398" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_87dc62dd" }, { "content": "from ultralytics.cfg import get_cfg", "chunk_type": "import", "name": "get_cfg", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cfg_cef24eef" }, { "content": "from ultralytics.engine.exporter import Exporter", "chunk_type": "import", "name": "Exporter", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Exporter_9e7cb540" }, { "content": "from ultralytics.models.yolo import classify, detect, segment", "chunk_type": "import", "name": "classify, detect, segment", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_classify, detect, segment_060b7ae7" }, { "content": "from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR", "chunk_type": "import", "name": "ASSETS, DEFAULT_CFG, WEIGHTS_DIR", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, DEFAULT_CFG, WEIGHTS_DIR_0e4dd227" }, { "content": "def test_func(*args): # noqa\n \"\"\"Test function callback for evaluating YOLO model performance metrics.\"\"\"\n print(\"callback test passed\")", "chunk_type": "function", "name": "test_func", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 14, "end_line": 16, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Test function callback for evaluating YOLO model performance metrics.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "sys", "unittest.mock", "tests.MODEL", "ultralytics.YOLO", "ultralytics.cfg.get_cfg", "ultralytics.engine.exporter.Exporter", "ultralytics.models.yolo.classify", "ultralytics.models.yolo.detect", "ultralytics.models.yolo.segment", "ultralytics.utils.ASSETS", 
"ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_test_func_1fda8d95" }, { "content": "def test_export():\n \"\"\"Test model exporting functionality by adding a callback and verifying its execution.\"\"\"\n exporter = Exporter()\n exporter.add_callback(\"on_export_start\", test_func)\n assert test_func in exporter.callbacks[\"on_export_start\"], \"callback test failed\"\n f = exporter(model=YOLO(\"yolo11n.yaml\").model)\n YOLO(f)(ASSETS) # exported model inference", "chunk_type": "function", "name": "test_export", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 19, "end_line": 25, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Test model exporting functionality by adding a callback and verifying its execution.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "sys", "unittest.mock", "tests.MODEL", "ultralytics.YOLO", "ultralytics.cfg.get_cfg", "ultralytics.engine.exporter.Exporter", "ultralytics.models.yolo.classify", "ultralytics.models.yolo.detect", "ultralytics.models.yolo.segment", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_test_export_97e3d98a" }, { "content": "def test_detect():\n \"\"\"Test YOLO object detection training, validation, and prediction functionality.\"\"\"\n overrides = {\"data\": \"coco8.yaml\", \"model\": \"yolo11n.yaml\", \"imgsz\": 32, \"epochs\": 1, \"save\": False}\n cfg = get_cfg(DEFAULT_CFG)\n cfg.data = \"coco8.yaml\"\n cfg.imgsz = 32\n\n # Trainer\n trainer = detect.DetectionTrainer(overrides=overrides)\n trainer.add_callback(\"on_train_start\", test_func)\n assert test_func in trainer.callbacks[\"on_train_start\"], \"callback test failed\"\n trainer.train()\n\n # Validator\n val = detect.DetectionValidator(args=cfg)\n val.add_callback(\"on_val_start\", test_func)\n assert test_func in val.callbacks[\"on_val_start\"], \"callback test failed\"\n val(model=trainer.best) # validate best.pt\n\n # Predictor\n pred = detect.DetectionPredictor(overrides={\"imgsz\": [64, 64]})\n pred.add_callback(\"on_predict_start\", test_func)\n assert test_func in pred.callbacks[\"on_predict_start\"], \"callback test failed\"\n # Confirm there is no issue with sys.argv being empty\n with mock.patch.object(sys, \"argv\", []):\n result = pred(source=ASSETS, model=MODEL)\n assert len(result), \"predictor test failed\"\n\n # Test resume functionality\n overrides[\"resume\"] = trainer.last\n trainer = detect.DetectionTrainer(overrides=overrides)\n try:\n trainer.train()\n except Exception as e:\n print(f\"Expected exception caught: {e}\")\n return\n\n raise Exception(\"Resume test failed!\")", "chunk_type": "function", "name": "test_detect", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 28, "end_line": 65, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Test YOLO object detection training, validation, and prediction functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "sys", "unittest.mock", "tests.MODEL", "ultralytics.YOLO", "ultralytics.cfg.get_cfg", "ultralytics.engine.exporter.Exporter", "ultralytics.models.yolo.classify", "ultralytics.models.yolo.detect", "ultralytics.models.yolo.segment", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_test_detect_f7b535ce" }, { "content": "def test_segment():\n \"\"\"Test image 
segmentation training, validation, and prediction pipelines using YOLO models.\"\"\"\n overrides = {\"data\": \"coco8-seg.yaml\", \"model\": \"yolo11n-seg.yaml\", \"imgsz\": 32, \"epochs\": 1, \"save\": False}\n cfg = get_cfg(DEFAULT_CFG)\n cfg.data = \"coco8-seg.yaml\"\n cfg.imgsz = 32\n\n # Trainer\n trainer = segment.SegmentationTrainer(overrides=overrides)\n trainer.add_callback(\"on_train_start\", test_func)\n assert test_func in trainer.callbacks[\"on_train_start\"], \"callback test failed\"\n trainer.train()\n\n # Validator\n val = segment.SegmentationValidator(args=cfg)\n val.add_callback(\"on_val_start\", test_func)\n assert test_func in val.callbacks[\"on_val_start\"], \"callback test failed\"\n val(model=trainer.best) # validate best.pt\n\n # Predictor\n pred = segment.SegmentationPredictor(overrides={\"imgsz\": [64, 64]})\n pred.add_callback(\"on_predict_start\", test_func)\n assert test_func in pred.callbacks[\"on_predict_start\"], \"callback test failed\"\n result = pred(source=ASSETS, model=WEIGHTS_DIR / \"yolo11n-seg.pt\")\n assert len(result), \"predictor test failed\"\n\n # Test resume functionality\n overrides[\"resume\"] = trainer.last\n trainer = segment.SegmentationTrainer(overrides=overrides)\n try:\n trainer.train()\n except Exception as e:\n print(f\"Expected exception caught: {e}\")\n return\n\n raise Exception(\"Resume test failed!\")", "chunk_type": "function", "name": "test_segment", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 68, "end_line": 103, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Test image segmentation training, validation, and prediction pipelines using YOLO models.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "sys", "unittest.mock", "tests.MODEL", "ultralytics.YOLO", "ultralytics.cfg.get_cfg", "ultralytics.engine.exporter.Exporter", "ultralytics.models.yolo.classify", "ultralytics.models.yolo.detect", "ultralytics.models.yolo.segment", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_test_segment_dca99aae" }, { "content": "def test_classify():\n \"\"\"Test image classification including training, validation, and prediction phases.\"\"\"\n overrides = {\"data\": \"imagenet10\", \"model\": \"yolo11n-cls.yaml\", \"imgsz\": 32, \"epochs\": 1, \"save\": False}\n cfg = get_cfg(DEFAULT_CFG)\n cfg.data = \"imagenet10\"\n cfg.imgsz = 32\n\n # Trainer\n trainer = classify.ClassificationTrainer(overrides=overrides)\n trainer.add_callback(\"on_train_start\", test_func)\n assert test_func in trainer.callbacks[\"on_train_start\"], \"callback test failed\"\n trainer.train()\n\n # Validator\n val = classify.ClassificationValidator(args=cfg)\n val.add_callback(\"on_val_start\", test_func)\n assert test_func in val.callbacks[\"on_val_start\"], \"callback test failed\"\n val(model=trainer.best)\n\n # Predictor\n pred = classify.ClassificationPredictor(overrides={\"imgsz\": [64, 64]})\n pred.add_callback(\"on_predict_start\", test_func)\n assert test_func in pred.callbacks[\"on_predict_start\"], \"callback test failed\"\n result = pred(source=ASSETS, model=trainer.best)\n assert len(result), \"predictor test failed\"", "chunk_type": "function", "name": "test_classify", "file_path": "ultralytics\\tests\\test_engine.py", "start_line": 106, "end_line": 130, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Test image classification including training, validation, and prediction phases.", 
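An illustrative aside on the callback pattern these engine tests exercise (`add_callback` on trainers, validators, and predictors): the same hook can be registered through the high-level `YOLO` API, as sketched below. The callback body and its print statement are assumptions for demonstration; the tests above use a bare `test_func` instead.

```python
# Sketch: registering an on_train_start callback via the high-level API.
from ultralytics import YOLO

def on_train_start(trainer):
    """Runs once when training starts; `trainer` is the active trainer instance."""
    print(f"Training started, saving to {trainer.save_dir}")

model = YOLO("yolo11n.yaml")
model.add_callback("on_train_start", on_train_start)
model.train(data="coco8.yaml", imgsz=32, epochs=1)
```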
"parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "sys", "unittest.mock", "tests.MODEL", "ultralytics.YOLO", "ultralytics.cfg.get_cfg", "ultralytics.engine.exporter.Exporter", "ultralytics.models.yolo.classify", "ultralytics.models.yolo.detect", "ultralytics.models.yolo.segment", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.WEIGHTS_DIR" ], "chunk_id": "function_test_classify_3f48c9fe" }, { "content": "import io", "chunk_type": "import", "name": "io", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_io_6d5015cc" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_9da7147c" }, { "content": "import uuid", "chunk_type": "import", "name": "uuid", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_uuid_5fc37519" }, { "content": "from contextlib import redirect_stderr, redirect_stdout", "chunk_type": "import", "name": "redirect_stderr, redirect_stdout", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_redirect_stderr, redirect_stdout_98cf5ed3" }, { "content": "from itertools import product", "chunk_type": "import", "name": "product", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_product_57acd384" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_3d748ab8" }, { "content": "import pytest", "chunk_type": "import", "name": "pytest", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pytest_b5db7ebc" }, { "content": "from tests import MODEL, SOURCE", "chunk_type": "import", "name": "MODEL, SOURCE", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
"dependencies": null, "chunk_id": "import_MODEL, SOURCE_419f039b" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_716fa55e" }, { "content": "from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS", "chunk_type": "import", "name": "TASK2DATA, TASK2MODEL, TASKS", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, TASK2MODEL, TASKS_aad5f5b5" }, { "content": "from ultralytics.utils import (\n ARM64,\n IS_RASPBERRYPI,\n LINUX,\n MACOS,\n WINDOWS,\n checks,\n)", "chunk_type": "import", "name": "ARM64, IS_RASPBERRYPI, LINUX, MACOS, WINDOWS, checks", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 15, "end_line": 22, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARM64, IS_RASPBERRYPI, LINUX, MACOS, WINDOWS, checks_5a9da276" }, { "content": "from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13", "chunk_type": "import", "name": "TORCH_1_9, TORCH_1_13", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_9, TORCH_1_13_8b790b09" }, { "content": "def test_export_torchscript():\n \"\"\"Test YOLO model export to TorchScript format for compatibility and correctness.\"\"\"\n file = YOLO(MODEL).export(format=\"torchscript\", optimize=False, imgsz=32)\n YOLO(file)(SOURCE, imgsz=32) # exported model inference", "chunk_type": "function", "name": "test_export_torchscript", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 26, "end_line": 29, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO model export to TorchScript format for compatibility and correctness.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_torchscript_adeffecc" }, { "content": "def test_export_onnx():\n \"\"\"Test YOLO model export to ONNX format with dynamic axes.\"\"\"\n file = YOLO(MODEL).export(format=\"onnx\", dynamic=True, imgsz=32)\n YOLO(file)(SOURCE, imgsz=32) # exported model inference", "chunk_type": "function", "name": "test_export_onnx", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 
32, "end_line": 35, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO model export to ONNX format with dynamic axes.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_onnx_7d34bd44" }, { "content": "def test_export_openvino():\n \"\"\"Test YOLO export to OpenVINO format for model inference compatibility.\"\"\"\n file = YOLO(MODEL).export(format=\"openvino\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32) # exported model inference", "chunk_type": "function", "name": "test_export_openvino", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 39, "end_line": 42, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO export to OpenVINO format for model inference compatibility.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not TORCH_1_13, reason='OpenVINO requires torch>=1.13')" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_openvino_ece5c563" }, { "content": "def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):\n \"\"\"Test YOLO model export to OpenVINO under various configuration matrix conditions.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"openvino\",\n imgsz=32,\n dynamic=dynamic,\n int8=int8,\n half=half,\n batch=batch,\n data=TASK2DATA[task],\n nms=nms,\n )\n if WINDOWS:\n # Use unique filenames due to Windows file permissions bug possibly due to latent threaded use\n # See https://github.com/ultralytics/ultralytics/actions/runs/8957949304/job/24601616830?pr=10423\n file = Path(file)\n file = file.rename(file.with_stem(f\"{file.stem}-{uuid.uuid4()}\"))\n YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference\n shutil.rmtree(file, ignore_errors=True) # retry in case of potential lingering multi-threaded file usage errors", "chunk_type": "function", "name": "test_export_openvino_matrix", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 57, "end_line": 75, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Test YOLO model export to OpenVINO under various configuration matrix conditions.", "parameters": [ "task", "dynamic", "int8", "half", "batch", "nms" ], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not TORCH_1_13, reason='OpenVINO requires torch>=1.13')", "pytest.mark.parametrize('task, dynamic, int8, half, batch, nms', [(task, 
dynamic, int8, half, batch, nms) for task, dynamic, int8, half, batch, nms in product(TASKS, [True, False], [True, False], [True, False], [1, 2], [True, False]) if not (int8 and half or (task == 'classify' and nms))])" ], "complexity_score": 3, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_openvino_matrix_bc80f24f" }, { "content": "def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):\n \"\"\"Test YOLO export to ONNX format with various configurations and parameters.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"onnx\", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, simplify=simplify, nms=nms\n )\n YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference\n Path(file).unlink() # cleanup", "chunk_type": "function", "name": "test_export_onnx_matrix", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 89, "end_line": 95, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test YOLO export to ONNX format with various configurations and parameters.", "parameters": [ "task", "dynamic", "int8", "half", "batch", "simplify", "nms" ], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.parametrize('task, dynamic, int8, half, batch, simplify, nms', [(task, dynamic, int8, half, batch, simplify, nms) for task, dynamic, int8, half, batch, simplify, nms in product(TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]) if not (int8 and half or (task == 'classify' and nms) or (task == 'obb' and nms and (not TORCH_1_13)))])" ], "complexity_score": 2, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_onnx_matrix_8d0b3fdd" }, { "content": "def test_export_torchscript_matrix(task, dynamic, int8, half, batch, nms):\n \"\"\"Test YOLO model export to TorchScript format under varied configurations.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"torchscript\", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms\n )\n YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference\n Path(file).unlink() # cleanup", "chunk_type": "function", "name": "test_export_torchscript_matrix", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 107, "end_line": 113, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test YOLO model export to TorchScript format under varied configurations.", "parameters": [ "task", "dynamic", "int8", "half", "batch", "nms" ], 
"return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.parametrize('task, dynamic, int8, half, batch, nms', [(task, dynamic, int8, half, batch, nms) for task, dynamic, int8, half, batch, nms in product(TASKS, [False], [False], [False], [1, 2], [True, False]) if not (task == 'classify' and nms)])" ], "complexity_score": 2, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_torchscript_matrix_4b9caa47" }, { "content": "def test_export_coreml_matrix(task, dynamic, int8, half, batch):\n \"\"\"Test YOLO export to CoreML format with various parameter configurations.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"coreml\",\n imgsz=32,\n dynamic=dynamic,\n int8=int8,\n half=half,\n batch=batch,\n )\n YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference\n shutil.rmtree(file) # cleanup", "chunk_type": "function", "name": "test_export_coreml_matrix", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 128, "end_line": 139, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test YOLO export to CoreML format with various parameter configurations.", "parameters": [ "task", "dynamic", "int8", "half", "batch" ], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not MACOS, reason='CoreML inference only supported on macOS')", "pytest.mark.skipif(not TORCH_1_9, reason='CoreML>=7.2 not supported with PyTorch<=1.8')", "pytest.mark.skipif(checks.IS_PYTHON_3_13, reason='CoreML not supported in Python 3.13')", "pytest.mark.parametrize('task, dynamic, int8, half, batch', [(task, dynamic, int8, half, batch) for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1]) if not (int8 and half)])" ], "complexity_score": 2, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_coreml_matrix_44ef060d" }, { "content": "def test_export_tflite_matrix(task, dynamic, int8, half, batch, nms):\n \"\"\"Test YOLO export to TFLite format considering various export configurations.\"\"\"\n file = YOLO(TASK2MODEL[task]).export(\n format=\"tflite\", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms\n )\n YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference\n Path(file).unlink() # cleanup", "chunk_type": "function", "name": "test_export_tflite_matrix", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 158, "end_line": 164, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test YOLO export to TFLite format 
considering various export configurations.", "parameters": [ "task", "dynamic", "int8", "half", "batch", "nms" ], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason='TFLite export requires Python>=3.10')", "pytest.mark.skipif(not LINUX or IS_RASPBERRYPI, reason='Test disabled as TF suffers from install conflicts on Windows, macOS and Raspberry Pi')", "pytest.mark.parametrize('task, dynamic, int8, half, batch, nms', [(task, dynamic, int8, half, batch, nms) for task, dynamic, int8, half, batch, nms in product(TASKS, [False], [True, False], [True, False], [1], [True, False]) if not (int8 and half or (task == 'classify' and nms) or (ARM64 and nms))])" ], "complexity_score": 2, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_tflite_matrix_fa126a03" }, { "content": "def test_export_coreml():\n \"\"\"Test YOLO export to CoreML format and check for errors.\"\"\"\n # Capture stdout and stderr\n stdout, stderr = io.StringIO(), io.StringIO()\n with redirect_stdout(stdout), redirect_stderr(stderr):\n YOLO(MODEL).export(format=\"coreml\", nms=True, imgsz=32)\n if MACOS:\n file = YOLO(MODEL).export(format=\"coreml\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32) # model prediction only supported on macOS for nms=False models\n\n # Check captured output for errors\n output = stdout.getvalue() + stderr.getvalue()\n assert \"Error\" not in output, f\"CoreML export produced errors: {output}\"\n assert \"You will not be able to run predict()\" not in output, \"CoreML export has predict() error\"", "chunk_type": "function", "name": "test_export_coreml", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 171, "end_line": 184, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": "Test YOLO export to CoreML format and check for errors.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not TORCH_1_9, reason='CoreML>=7.2 not supported with PyTorch<=1.8')", "pytest.mark.skipif(WINDOWS, reason='CoreML not supported on Windows')", "pytest.mark.skipif(LINUX and ARM64, reason='CoreML not supported on aarch64 Linux')", "pytest.mark.skipif(checks.IS_PYTHON_3_13, reason='CoreML not supported in Python 3.13')" ], "complexity_score": 2, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_coreml_b1a1c800" }, { "content": "def test_export_tflite():\n \"\"\"Test YOLO export to TFLite format under specific OS and Python version conditions.\"\"\"\n model = YOLO(MODEL)\n 
file = model.export(format=\"tflite\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32)", "chunk_type": "function", "name": "test_export_tflite", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 189, "end_line": 193, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO export to TFLite format under specific OS and Python version conditions.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason='TFLite export requires Python>=3.10')", "pytest.mark.skipif(not LINUX, reason='Test disabled as TF suffers from install conflicts on Windows and macOS')" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_tflite_7bffcd60" }, { "content": "def test_export_pb():\n \"\"\"Test YOLO export to TensorFlow's Protobuf (*.pb) format.\"\"\"\n model = YOLO(MODEL)\n file = model.export(format=\"pb\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32)", "chunk_type": "function", "name": "test_export_pb", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 198, "end_line": 202, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO export to TensorFlow's Protobuf (*.pb) format.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(True, reason='Test disabled')", "pytest.mark.skipif(not LINUX, reason='TF suffers from install conflicts on Windows and macOS')" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_pb_daf73fbf" }, { "content": "def test_export_paddle():\n \"\"\"Test YOLO export to Paddle format, noting protobuf conflicts with ONNX.\"\"\"\n YOLO(MODEL).export(format=\"paddle\", imgsz=32)", "chunk_type": "function", "name": "test_export_paddle", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 206, "end_line": 208, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": "Test YOLO export to Paddle format, noting protobuf conflicts with ONNX.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(True, reason='Test disabled as Paddle protobuf and ONNX protobuf requirements conflict.')" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", 
"ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_paddle_6ad3bb42" }, { "content": "def test_export_mnn():\n \"\"\"Test YOLO export to MNN format (WARNING: MNN test must precede NCNN test or CI error on Windows).\"\"\"\n file = YOLO(MODEL).export(format=\"mnn\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32) # exported model inference", "chunk_type": "function", "name": "test_export_mnn", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 212, "end_line": 215, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO export to MNN format (WARNING: MNN test must precede NCNN test or CI error on Windows).", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_mnn_37890eb8" }, { "content": "def test_export_ncnn():\n \"\"\"Test YOLO export to NCNN format.\"\"\"\n file = YOLO(MODEL).export(format=\"ncnn\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32) # exported model inference", "chunk_type": "function", "name": "test_export_ncnn", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 219, "end_line": 222, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO export to NCNN format.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_ncnn_3b486a18" }, { "content": "def test_export_imx():\n \"\"\"Test YOLO export to IMX format.\"\"\"\n model = YOLO(\"yolov8n.pt\")\n file = model.export(format=\"imx\", imgsz=32)\n YOLO(file)(SOURCE, imgsz=32)", "chunk_type": "function", "name": "test_export_imx", "file_path": "ultralytics\\tests\\test_exports.py", "start_line": 227, "end_line": 231, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO export to IMX format.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(True, reason='Test disabled as keras and tensorflow version conflicts with TFlite export.')", "pytest.mark.skipif(not LINUX or MACOS, reason='Skipping test on Windows and Macos')" ], "complexity_score": 1, "dependencies": [ "io", "shutil", "uuid", "contextlib.redirect_stderr", "contextlib.redirect_stdout", "itertools.product", 
"pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2MODEL", "ultralytics.cfg.TASKS", "ultralytics.utils.ARM64", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.LINUX", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_test_export_imx_e54fbc02" }, { "content": "import contextlib", "chunk_type": "import", "name": "contextlib", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextlib_1a6f7e1a" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_555493c3" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_81d8c871" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_378f92b7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_935427ed" }, { "content": "import pytest", "chunk_type": "import", "name": "pytest", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pytest_8fcd72c2" }, { "content": "from tests import MODEL, SOURCE, TMP", "chunk_type": "import", "name": "MODEL, SOURCE, TMP", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MODEL, SOURCE, TMP_339cbbdd" }, { "content": "from ultralytics import YOLO, download", "chunk_type": "import", "name": "YOLO, download", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": 
null, "dependencies": null, "chunk_id": "import_YOLO, download_b5c4802a" }, { "content": "from ultralytics.utils import DATASETS_DIR, SETTINGS", "chunk_type": "import", "name": "DATASETS_DIR, SETTINGS", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DATASETS_DIR, SETTINGS_28a35f76" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_7d922aae" }, { "content": "def test_tensorboard():\n \"\"\"Test training with TensorBoard logging enabled.\"\"\"\n SETTINGS[\"tensorboard\"] = True\n YOLO(\"yolo11n-cls.yaml\").train(data=\"imagenet10\", imgsz=32, epochs=3, plots=False, device=\"cpu\")\n SETTINGS[\"tensorboard\"] = False", "chunk_type": "function", "name": "test_tensorboard", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 18, "end_line": 22, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Test training with TensorBoard logging enabled.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow" ], "complexity_score": 1, "dependencies": [ "contextlib", "os", "subprocess", "time", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "tests.TMP", "ultralytics.YOLO", "ultralytics.download", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks.check_requirements", "mlflow", "tritonclient.http.InferenceServerClient", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.pose.PoseValidator", "ultralytics.models.yolo.segment.SegmentationValidator" ], "chunk_id": "function_test_tensorboard_77be5976" }, { "content": "def test_model_ray_tune():\n \"\"\"Tune YOLO model using Ray for hyperparameter optimization.\"\"\"\n YOLO(\"yolo11n-cls.yaml\").tune(\n use_ray=True, data=\"imagenet10\", grace_period=1, iterations=1, imgsz=32, epochs=1, plots=False, device=\"cpu\"\n )", "chunk_type": "function", "name": "test_model_ray_tune", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 26, "end_line": 30, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Tune YOLO model using Ray for hyperparameter optimization.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not check_requirements('ray', install=False), reason='ray[tune] not installed')" ], "complexity_score": 1, "dependencies": [ "contextlib", "os", "subprocess", "time", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "tests.TMP", "ultralytics.YOLO", "ultralytics.download", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks.check_requirements", "mlflow", "tritonclient.http.InferenceServerClient", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.pose.PoseValidator", "ultralytics.models.yolo.segment.SegmentationValidator" ], "chunk_id": "function_test_model_ray_tune_c9412d64" }, { "content": "def test_mlflow():\n \"\"\"Test training with MLflow tracking enabled.\"\"\"\n 
SETTINGS[\"mlflow\"] = True\n YOLO(\"yolo11n-cls.yaml\").train(data=\"imagenet10\", imgsz=32, epochs=3, plots=False, device=\"cpu\")\n SETTINGS[\"mlflow\"] = False", "chunk_type": "function", "name": "test_mlflow", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 34, "end_line": 38, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Test training with MLflow tracking enabled.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed')" ], "complexity_score": 1, "dependencies": [ "contextlib", "os", "subprocess", "time", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "tests.TMP", "ultralytics.YOLO", "ultralytics.download", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks.check_requirements", "mlflow", "tritonclient.http.InferenceServerClient", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.pose.PoseValidator", "ultralytics.models.yolo.segment.SegmentationValidator" ], "chunk_id": "function_test_mlflow_40323fae" }, { "content": "def test_mlflow_keep_run_active():\n \"\"\"Ensure MLflow run status matches MLFLOW_KEEP_RUN_ACTIVE environment variable settings.\"\"\"\n import mlflow\n\n SETTINGS[\"mlflow\"] = True\n run_name = \"Test Run\"\n os.environ[\"MLFLOW_RUN\"] = run_name\n\n # Test with MLFLOW_KEEP_RUN_ACTIVE=True\n os.environ[\"MLFLOW_KEEP_RUN_ACTIVE\"] = \"True\"\n YOLO(\"yolo11n-cls.yaml\").train(data=\"imagenet10\", imgsz=32, epochs=1, plots=False, device=\"cpu\")\n status = mlflow.active_run().info.status\n assert status == \"RUNNING\", \"MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True\"\n\n run_id = mlflow.active_run().info.run_id\n\n # Test with MLFLOW_KEEP_RUN_ACTIVE=False\n os.environ[\"MLFLOW_KEEP_RUN_ACTIVE\"] = \"False\"\n YOLO(\"yolo11n-cls.yaml\").train(data=\"imagenet10\", imgsz=32, epochs=1, plots=False, device=\"cpu\")\n status = mlflow.get_run(run_id=run_id).info.status\n assert status == \"FINISHED\", \"MLflow run should be ended when MLFLOW_KEEP_RUN_ACTIVE=False\"\n\n # Test with MLFLOW_KEEP_RUN_ACTIVE not set\n os.environ.pop(\"MLFLOW_KEEP_RUN_ACTIVE\", None)\n YOLO(\"yolo11n-cls.yaml\").train(data=\"imagenet10\", imgsz=32, epochs=1, plots=False, device=\"cpu\")\n status = mlflow.get_run(run_id=run_id).info.status\n assert status == \"FINISHED\", \"MLflow run should be ended by default when MLFLOW_KEEP_RUN_ACTIVE is not set\"\n SETTINGS[\"mlflow\"] = False", "chunk_type": "function", "name": "test_mlflow_keep_run_active", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 43, "end_line": 70, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Ensure MLflow run status matches MLFLOW_KEEP_RUN_ACTIVE environment variable settings.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(True, reason='Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868')", "pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed')" ], "complexity_score": 1, "dependencies": [ "contextlib", "os", "subprocess", "time", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "tests.TMP", "ultralytics.YOLO", "ultralytics.download", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks.check_requirements", "mlflow", "tritonclient.http.InferenceServerClient", "ultralytics.models.yolo.detect.DetectionValidator", 
"ultralytics.models.yolo.pose.PoseValidator", "ultralytics.models.yolo.segment.SegmentationValidator" ], "chunk_id": "function_test_mlflow_keep_run_active_1891feb7" }, { "content": "def test_triton():\n \"\"\"Test NVIDIA Triton Server functionalities with YOLO model.\"\"\"\n check_requirements(\"tritonclient[all]\")\n from tritonclient.http import InferenceServerClient # noqa\n\n # Create variables\n model_name = \"yolo\"\n triton_repo = TMP / \"triton_repo\" # Triton repo path\n triton_model = triton_repo / model_name # Triton model path\n\n # Export model to ONNX\n f = YOLO(MODEL).export(format=\"onnx\", dynamic=True)\n\n # Prepare Triton repo\n (triton_model / \"1\").mkdir(parents=True, exist_ok=True)\n Path(f).rename(triton_model / \"1\" / \"model.onnx\")\n (triton_model / \"config.pbtxt\").touch()\n\n # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver\n tag = \"nvcr.io/nvidia/tritonserver:23.09-py3\" # 6.4 GB\n\n # Pull the image\n subprocess.call(f\"docker pull {tag}\", shell=True)\n\n # Run the Triton server and capture the container ID\n container_id = (\n subprocess.check_output(\n f\"docker run -d --rm -v {triton_repo}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models\",\n shell=True,\n )\n .decode(\"utf-8\")\n .strip()\n )\n\n # Wait for the Triton server to start\n triton_client = InferenceServerClient(url=\"localhost:8000\", verbose=False, ssl=False)\n\n # Wait until model is ready\n for _ in range(10):\n with contextlib.suppress(Exception):\n assert triton_client.is_model_ready(model_name)\n break\n time.sleep(1)\n\n # Check Triton inference\n YOLO(f\"http://localhost:8000/{model_name}\", \"detect\")(SOURCE) # exported model inference\n\n # Kill and remove the container at the end of the test\n subprocess.call(f\"docker kill {container_id}\", shell=True)", "chunk_type": "function", "name": "test_triton", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 74, "end_line": 122, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "Test NVIDIA Triton Server functionalities with YOLO model.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not check_requirements('tritonclient', install=False), reason='tritonclient[all] not installed')" ], "complexity_score": 2, "dependencies": [ "contextlib", "os", "subprocess", "time", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "tests.TMP", "ultralytics.YOLO", "ultralytics.download", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks.check_requirements", "mlflow", "tritonclient.http.InferenceServerClient", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.pose.PoseValidator", "ultralytics.models.yolo.segment.SegmentationValidator" ], "chunk_id": "function_test_triton_9cf3bd6f" }, { "content": "def test_faster_coco_eval():\n \"\"\"Validate YOLO model predictions on COCO dataset using faster-coco-eval.\"\"\"\n from ultralytics.models.yolo.detect import DetectionValidator\n from ultralytics.models.yolo.pose import PoseValidator\n from ultralytics.models.yolo.segment import SegmentationValidator\n\n # Download annotations after each dataset downloads first\n url = \"https://github.com/ultralytics/assets/releases/download/v0.0.0/\"\n\n args = {\"model\": \"yolo11n.pt\", \"data\": \"coco8.yaml\", \"save_json\": True, \"imgsz\": 64}\n validator = DetectionValidator(args=args)\n validator()\n validator.is_coco = True\n download(f\"{url}instances_val2017.json\", 
dir=DATASETS_DIR / \"coco8/annotations\")\n _ = validator.eval_json(validator.stats)\n\n args = {\"model\": \"yolo11n-seg.pt\", \"data\": \"coco8-seg.yaml\", \"save_json\": True, \"imgsz\": 64}\n validator = SegmentationValidator(args=args)\n validator()\n validator.is_coco = True\n download(f\"{url}instances_val2017.json\", dir=DATASETS_DIR / \"coco8-seg/annotations\")\n _ = validator.eval_json(validator.stats)\n\n args = {\"model\": \"yolo11n-pose.pt\", \"data\": \"coco8-pose.yaml\", \"save_json\": True, \"imgsz\": 64}\n validator = PoseValidator(args=args)\n validator()\n validator.is_coco = True\n download(f\"{url}person_keypoints_val2017.json\", dir=DATASETS_DIR / \"coco8-pose/annotations\")\n _ = validator.eval_json(validator.stats)", "chunk_type": "function", "name": "test_faster_coco_eval", "file_path": "ultralytics\\tests\\test_integrations.py", "start_line": 126, "end_line": 154, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Validate YOLO model predictions on COCO dataset using faster-coco-eval.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not check_requirements('faster-coco-eval', install=False), reason='faster-coco-eval not installed')" ], "complexity_score": 1, "dependencies": [ "contextlib", "os", "subprocess", "time", "pathlib.Path", "pytest", "tests.MODEL", "tests.SOURCE", "tests.TMP", "ultralytics.YOLO", "ultralytics.download", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks.check_requirements", "mlflow", "tritonclient.http.InferenceServerClient", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.pose.PoseValidator", "ultralytics.models.yolo.segment.SegmentationValidator" ], "chunk_id": "function_test_faster_coco_eval_6b6ff6ea" }, { "content": "import contextlib", "chunk_type": "import", "name": "contextlib", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextlib_f49cf65c" }, { "content": "import csv", "chunk_type": "import", "name": "csv", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_csv_9e1d5aec" }, { "content": "import urllib", "chunk_type": "import", "name": "urllib", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_urllib_02f67365" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_d1dfadfe" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": 
null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_c7bdfce6" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_7ae2d929" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_9d7eb1ce" }, { "content": "import pytest", "chunk_type": "import", "name": "pytest", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pytest_5188e459" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_4265cd33" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_396c2821" }, { "content": "from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP", "chunk_type": "import", "name": "CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP_78679ffe" }, { "content": "from ultralytics import RTDETR, YOLO", "chunk_type": "import", "name": "RTDETR, YOLO", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETR, YOLO_163bbc7a" }, { "content": "from ultralytics.cfg import TASK2DATA, TASKS", "chunk_type": "import", "name": "TASK2DATA, TASKS", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, TASKS_6a33bbbc" }, { "content": "from ultralytics.data.build import load_inference_source", "chunk_type": "import", "name": "load_inference_source", "file_path": 
"ultralytics\\tests\\test_python.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_load_inference_source_c0ebf686" }, { "content": "from ultralytics.data.utils import check_det_dataset", "chunk_type": "import", "name": "check_det_dataset", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_det_dataset_57105f5d" }, { "content": "from ultralytics.utils import (\n ARM64,\n ASSETS,\n DEFAULT_CFG,\n DEFAULT_CFG_PATH,\n LINUX,\n LOGGER,\n ONLINE,\n ROOT,\n WEIGHTS_DIR,\n WINDOWS,\n YAML,\n checks,\n is_dir_writeable,\n is_github_action_running,\n)", "chunk_type": "import", "name": "ARM64, ASSETS, DEFAULT_CFG, DEFAULT_CFG_PATH, LINUX, LOGGER, ONLINE, ROOT, WEIGHTS_DIR, WINDOWS, YAML, checks, is_dir_writeable, is_github_action_running", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 20, "end_line": 35, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARM64, ASSETS, DEFAULT_CFG, DEFAULT_CFG_PATH, LINUX, LOGGER, ONLINE, ROOT, WEIGHTS_DIR, WINDOWS, YAML, checks, is_dir_writeable, is_github_action_running_86631bdb" }, { "content": "from ultralytics.utils.downloads import download", "chunk_type": "import", "name": "download", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_download_08db52f2" }, { "content": "from ultralytics.utils.torch_utils import TORCH_1_9", "chunk_type": "import", "name": "TORCH_1_9", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_9_cc1f173a" }, { "content": "IS_TMP_WRITEABLE = is_dir_writeable(TMP) # WARNING: must be run once tests start as TMP does not exist on tests/init", "chunk_type": "variable", "name": "IS_TMP_WRITEABLE", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_TMP_WRITEABLE_4af38b88" }, { "content": "def test_model_forward():\n \"\"\"Test the forward pass of the YOLO model.\"\"\"\n model = YOLO(CFG)\n model(source=None, imgsz=32, augment=True) # also test no source and augment", "chunk_type": "function", "name": "test_model_forward", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 42, "end_line": 45, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Test the forward pass of the YOLO model.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", 
"pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_model_forward_00fa6ff7" }, { "content": "def test_model_methods():\n \"\"\"Test various methods and properties of the YOLO model to ensure correct functionality.\"\"\"\n model = YOLO(MODEL)\n\n # Model methods\n model.info(verbose=True, detailed=True)\n model = model.reset_weights()\n model = model.load(MODEL)\n model.to(\"cpu\")\n model.fuse()\n model.clear_callback(\"on_train_start\")\n 
model.reset_callbacks()\n\n # Model properties\n _ = model.names\n _ = model.device\n _ = model.transforms\n _ = model.task_map", "chunk_type": "function", "name": "test_model_methods", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 48, "end_line": 65, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Test various methods and properties of the YOLO model to ensure correct functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", 
"ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_model_methods_c2ddf691" }, { "content": "def test_model_profile():\n \"\"\"Test profiling of the YOLO model with `profile=True` to assess performance and resource usage.\"\"\"\n from ultralytics.nn.tasks import DetectionModel\n\n model = DetectionModel() # build model\n im = torch.randn(1, 3, 64, 64) # requires min imgsz=64\n _ = model.predict(im, profile=True)", "chunk_type": "function", "name": "test_model_profile", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 68, "end_line": 74, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "Test profiling of the YOLO model with `profile=True` to assess performance and resource usage.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", 
"ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_model_profile_57fc8d84" }, { "content": "def test_predict_txt():\n \"\"\"Test YOLO predictions with file, directory, and pattern sources listed in a text file.\"\"\"\n file = TMP / \"sources_multi_row.txt\"\n with open(file, \"w\") as f:\n for src in SOURCES_LIST:\n f.write(f\"{src}\\n\")\n results = YOLO(MODEL)(source=file, imgsz=32)\n assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images", "chunk_type": "function", "name": "test_predict_txt", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 78, "end_line": 85, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Test YOLO predictions with file, directory, and pattern sources listed in a text file.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not IS_TMP_WRITEABLE, reason='directory is not writeable')" ], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", 
"ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_txt_8aff8783" }, { "content": "def test_predict_csv_multi_row():\n \"\"\"Test YOLO predictions with sources listed in multiple rows of a CSV file.\"\"\"\n file = TMP / \"sources_multi_row.csv\"\n with open(file, \"w\", newline=\"\") as f:\n writer = csv.writer(f)\n writer.writerow([\"source\"])\n writer.writerows([[src] for src in SOURCES_LIST])\n results = YOLO(MODEL)(source=file, imgsz=32)\n assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images", "chunk_type": "function", "name": "test_predict_csv_multi_row", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 90, "end_line": 98, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Test YOLO predictions with sources listed in multiple rows of a CSV file.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(True, reason='disabled for testing')", "pytest.mark.skipif(not IS_TMP_WRITEABLE, reason='directory is not writeable')" ], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", 
"ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_csv_multi_row_af609b05" }, { "content": "def test_predict_csv_single_row():\n \"\"\"Test YOLO predictions with sources listed in a single row of a CSV file.\"\"\"\n file = TMP / \"sources_single_row.csv\"\n with open(file, \"w\", newline=\"\") as f:\n writer = csv.writer(f)\n writer.writerow(SOURCES_LIST)\n results = YOLO(MODEL)(source=file, imgsz=32)\n assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images", "chunk_type": "function", "name": "test_predict_csv_single_row", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 103, "end_line": 110, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Test YOLO predictions with sources listed in a single row of a CSV file.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(True, reason='disabled for testing')", "pytest.mark.skipif(not IS_TMP_WRITEABLE, reason='directory is not writeable')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", 
"ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_csv_single_row_597fec3f" }, { "content": "def test_predict_img(model_name):\n \"\"\"Test YOLO model predictions on various image input types and sources, including online images.\"\"\"\n channels = 1 if model_name == \"yolo11n-grayscale.pt\" else 3\n model = YOLO(WEIGHTS_DIR / model_name)\n im = cv2.imread(str(SOURCE), flags=cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR) # uint8 numpy array\n assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL\n assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray\n assert len(model(torch.rand((2, channels, 32, 32)), imgsz=32)) == 2 # batch-size 2 Tensor, FP32 0.0-1.0 RGB order\n assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2 # batch\n assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2 # stream\n assert len(model(torch.zeros(320, 640, channels).numpy().astype(np.uint8), imgsz=32)) == 1 # tensor to numpy\n batch = [\n str(SOURCE), # filename\n Path(SOURCE), # Path\n 
\"https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg\" if ONLINE else SOURCE, # URI\n im, # OpenCV\n Image.open(SOURCE), # PIL\n np.zeros((320, 640, channels), dtype=np.uint8), # numpy\n ]\n assert len(model(batch, imgsz=32, classes=0)) == len(batch) # multiple sources in a batch", "chunk_type": "function", "name": "test_predict_img", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 114, "end_line": 133, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": "Test YOLO model predictions on various image input types and sources, including online images.", "parameters": [ "model_name" ], "return_type": null, "decorators": [ "pytest.mark.parametrize('model_name', MODELS)" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", 
"ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_img_38830bfe" }, { "content": "def test_predict_visualize(model):\n \"\"\"Test model prediction methods with 'visualize=True' to generate and display prediction visualizations.\"\"\"\n YOLO(WEIGHTS_DIR / model)(SOURCE, imgsz=32, visualize=True)", "chunk_type": "function", "name": "test_predict_visualize", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 137, "end_line": 139, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": "Test model prediction methods with 'visualize=True' to generate and display prediction visualizations.", "parameters": [ "model" ], "return_type": null, "decorators": [ "pytest.mark.parametrize('model', MODELS)" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", 
"ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_visualize_be082d7d" }, { "content": "def test_predict_grey_and_4ch():\n \"\"\"Test YOLO prediction on SOURCE converted to greyscale and 4-channel images with various filenames.\"\"\"\n im = Image.open(SOURCE)\n directory = TMP / \"im4\"\n directory.mkdir(parents=True, exist_ok=True)\n\n source_greyscale = directory / \"greyscale.jpg\"\n source_rgba = directory / \"4ch.png\"\n source_non_utf = directory / \"non_UTF_测试文件_tést_image.jpg\"\n source_spaces = directory / \"image with spaces.jpg\"\n\n im.convert(\"L\").save(source_greyscale) # greyscale\n im.convert(\"RGBA\").save(source_rgba) # 4-ch PNG with alpha\n im.save(source_non_utf) # non-UTF characters in filename\n im.save(source_spaces) # spaces in filename\n\n # Inference\n model = YOLO(MODEL)\n for f in source_rgba, source_greyscale, source_non_utf, source_spaces:\n for source in Image.open(f), cv2.imread(str(f)), f:\n results = model(source, save=True, verbose=True, imgsz=32)\n assert len(results) == 1 # verify that an image was run\n f.unlink() # cleanup", "chunk_type": "function", "name": "test_predict_grey_and_4ch", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 142, "end_line": 164, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Test YOLO prediction on SOURCE converted to greyscale and 4-channel images with various filenames.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", 
"ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_grey_and_4ch_af224775" }, { "content": "def test_youtube():\n \"\"\"Test YOLO model on a YouTube video stream, handling potential network-related errors.\"\"\"\n model = YOLO(MODEL)\n try:\n model.predict(\"https://youtu.be/G17sBkb38XQ\", imgsz=96, save=True)\n # Handle internet connection errors and 'urllib.error.HTTPError: HTTP Error 429: Too Many Requests'\n except (urllib.error.HTTPError, ConnectionError) as e:\n LOGGER.error(f\"YouTube Test Error: {e}\")", "chunk_type": "function", "name": "test_youtube", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 170, "end_line": 177, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": "Test YOLO model on a YouTube video stream, handling potential network-related errors.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not ONLINE, reason='environment is offline')", "pytest.mark.skipif(is_github_action_running(), reason='No auth https://github.com/JuanBindez/pytubefix/issues/166')" ], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", 
"ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_youtube_073c4f2a" }, { "content": "def test_track_stream(model):\n \"\"\"\n Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.\n\n Note imgsz=160 required for tracking for higher confidence and better matches.\n \"\"\"\n if model == \"yolo11n-cls.pt\": # classification model not supported for tracking\n return\n video_url = \"https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov\"\n model = YOLO(model)\n model.track(video_url, imgsz=160, tracker=\"bytetrack.yaml\")\n model.track(video_url, imgsz=160, tracker=\"botsort.yaml\", save_frames=True) # test frame saving also\n\n # Test Global Motion Compensation (GMC) methods and ReID\n for gmc, reidm in zip([\"orb\", \"sift\", \"ecc\"], [\"auto\", \"auto\", \"yolo11n-cls.pt\"]):\n default_args = YAML.load(ROOT / \"cfg/trackers/botsort.yaml\")\n custom_yaml 
= TMP / f\"botsort-{gmc}.yaml\"\n YAML.save(custom_yaml, {**default_args, \"gmc_method\": gmc, \"with_reid\": True, \"model\": reidm})\n model.track(video_url, imgsz=160, tracker=custom_yaml)", "chunk_type": "function", "name": "test_track_stream", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 183, "end_line": 201, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.\n\nNote imgsz=160 required for tracking for higher confidence and better matches.", "parameters": [ "model" ], "return_type": null, "decorators": [ "pytest.mark.skipif(not ONLINE, reason='environment is offline')", "pytest.mark.skipif(not IS_TMP_WRITEABLE, reason='directory is not writeable')", "pytest.mark.parametrize('model', MODELS)" ], "complexity_score": 3, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", 
"ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_track_stream_86522184" }, { "content": "def test_val(task: str, model: str, data: str) -> None:\n \"\"\"Test the validation mode of the YOLO model.\"\"\"\n for plots in {True, False}: # Test both cases i.e. plots=True and plots=False\n metrics = YOLO(model).val(data=data, imgsz=32, plots=plots)\n metrics.to_df()\n metrics.to_csv()\n metrics.to_xml()\n metrics.to_html()\n metrics.to_json()\n metrics.to_sql()\n metrics.confusion_matrix.to_df() # Tests for confusion matrix export\n metrics.confusion_matrix.to_csv()\n metrics.confusion_matrix.to_xml()\n metrics.confusion_matrix.to_html()\n metrics.confusion_matrix.to_json()\n metrics.confusion_matrix.to_sql()", "chunk_type": "function", "name": "test_val", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 205, "end_line": 220, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Test the validation mode of the YOLO model.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('task,model,data', TASK_MODEL_DATA)" ], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", 
"ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_val_144aecb9" }, { "content": "def test_train_scratch():\n \"\"\"Test training the YOLO model from scratch using the provided configuration.\"\"\"\n model = YOLO(CFG)\n model.train(data=\"coco8.yaml\", epochs=2, imgsz=32, cache=\"disk\", batch=-1, close_mosaic=1, name=\"model\")\n model(SOURCE)", "chunk_type": "function", "name": "test_train_scratch", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 223, "end_line": 227, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Test training the YOLO model from scratch using the provided configuration.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", 
"ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_train_scratch_33a0cb32" }, { "content": "def test_train_pretrained(scls):\n \"\"\"Test training of the YOLO model starting from a pre-trained checkpoint.\"\"\"\n model = YOLO(WEIGHTS_DIR / \"yolo11n-seg.pt\")\n model.train(\n data=\"coco8-seg.yaml\", epochs=1, imgsz=32, cache=\"ram\", copy_paste=0.5, mixup=0.5, name=0, single_cls=scls\n )\n model(SOURCE)", "chunk_type": "function", "name": "test_train_pretrained", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 231, "end_line": 237, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Test training of the YOLO model starting from a pre-trained checkpoint.", "parameters": [ "scls" ], "return_type": null, "decorators": [ "pytest.mark.parametrize('scls', [False, True])" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", 
"ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_train_pretrained_0760c91f" }, { "content": "def test_all_model_yamls():\n \"\"\"Test YOLO model creation for all available YAML configurations in the `cfg/models` directory.\"\"\"\n for m in (ROOT / \"cfg\" / \"models\").rglob(\"*.yaml\"):\n if \"rtdetr\" in m.name:\n if TORCH_1_9: # torch<=1.8 issue - TypeError: __init__() got an unexpected keyword argument 'batch_first'\n _ = RTDETR(m.name)(SOURCE, imgsz=640) # must be 640\n else:\n YOLO(m.name)", "chunk_type": "function", "name": "test_all_model_yamls", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 240, "end_line": 247, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Test YOLO model creation for all available YAML configurations in the `cfg/models` directory.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", 
"ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_all_model_yamls_16f388f4" }, { "content": "def test_workflow():\n \"\"\"Test the complete workflow including training, validation, prediction, and exporting.\"\"\"\n model = YOLO(MODEL)\n model.train(data=\"coco8.yaml\", epochs=1, imgsz=32, optimizer=\"SGD\")\n model.val(imgsz=32)\n model.predict(SOURCE, imgsz=32)\n model.export(format=\"torchscript\") # WARNING: Windows slow CI export bug", "chunk_type": "function", "name": "test_workflow", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 251, "end_line": 257, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Test the complete workflow including training, validation, prediction, and exporting.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(WINDOWS, reason='Windows slow CI export bug https://github.com/ultralytics/ultralytics/pull/16003')" ], "complexity_score": 1, 
"dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_workflow_0fecee9c" }, { "content": "def test_predict_callback_and_setup():\n \"\"\"Test callback functionality during YOLO prediction setup and execution.\"\"\"\n\n def on_predict_batch_end(predictor):\n \"\"\"Callback function that handles operations at the end of a prediction batch.\"\"\"\n path, im0s, _ = predictor.batch\n im0s = 
im0s if isinstance(im0s, list) else [im0s]\n bs = [predictor.dataset.bs for _ in range(len(path))]\n predictor.results = zip(predictor.results, im0s, bs) # results is List[batch_size]\n\n model = YOLO(MODEL)\n model.add_callback(\"on_predict_batch_end\", on_predict_batch_end)\n\n dataset = load_inference_source(source=SOURCE)\n bs = dataset.bs # noqa access predictor properties\n results = model.predict(dataset, stream=True, imgsz=160) # source already setup\n for r, im0, bs in results:\n print(\"test_callback\", im0.shape)\n print(\"test_callback\", bs)\n boxes = r.boxes # Boxes object for bbox outputs\n print(boxes)", "chunk_type": "function", "name": "test_predict_callback_and_setup", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 260, "end_line": 280, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Test callback functionality during YOLO prediction setup and execution.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", 
"ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_predict_callback_and_setup_0a909c98" }, { "content": "def test_results(model: str):\n \"\"\"Test YOLO model results processing and output in various formats.\"\"\"\n temp_s = \"https://ultralytics.com/images/boats.jpg\" if model == \"yolo11n-obb.pt\" else SOURCE\n results = YOLO(WEIGHTS_DIR / model)([temp_s, temp_s], imgsz=160)\n for r in results:\n assert len(r), f\"'{model}' results should not be empty!\"\n r = r.cpu().numpy()\n print(r, len(r), r.path) # print numpy attributes\n r = r.to(device=\"cpu\", dtype=torch.float32)\n r.save_txt(txt_file=TMP / \"runs/tests/label.txt\", save_conf=True)\n r.save_crop(save_dir=TMP / \"runs/tests/crops/\")\n r.to_df(decimals=3) # Align to_ methods: https://docs.ultralytics.com/modes/predict/#working-with-results\n r.to_csv()\n r.to_xml()\n r.to_html()\n r.to_json(normalize=True)\n r.to_sql()\n r.plot(pil=True, save=True, filename=TMP / \"results_plot_save.jpg\")\n r.plot(conf=True, boxes=True)\n print(r, len(r), r.path) # print after methods", "chunk_type": "function", "name": "test_results", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 284, "end_line": 303, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test YOLO model results processing and output in various formats.", "parameters": [ "model: str" ], "return_type": null, "decorators": [ "pytest.mark.parametrize('model', MODELS)" ], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", 
"ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_results_fafc9ae6" }, { "content": "def test_labels_and_crops():\n \"\"\"Test output from prediction args for saving YOLO detection labels and crops.\"\"\"\n imgs = [SOURCE, ASSETS / \"zidane.jpg\"]\n results = YOLO(WEIGHTS_DIR / \"yolo11n.pt\")(imgs, imgsz=160, save_txt=True, save_crop=True)\n save_path = Path(results[0].save_dir)\n for r in results:\n im_name = Path(r.path).stem\n cls_idxs = r.boxes.cls.int().tolist()\n # Check correct detections\n assert cls_idxs == ([0, 7, 0, 0] if r.path.endswith(\"bus.jpg\") else [0, 0, 0]) # bus.jpg and zidane.jpg classes\n # Check label path\n labels = save_path / f\"labels/{im_name}.txt\"\n assert labels.exists()\n # Check detections match label count\n assert len(r.boxes.data) == len([line for line in labels.read_text().splitlines() if line])\n # Check crops path and files\n crop_dirs = list((save_path / \"crops\").iterdir())\n crop_files = [f for p in crop_dirs for f in p.glob(\"*\")]\n # Crop directories match detections\n assert all(r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs)\n # Same number of crops as detections\n assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)", "chunk_type": "function", "name": "test_labels_and_crops", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 306, "end_line": 327, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": "Test output from prediction args for saving YOLO detection labels and crops.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", 
"tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_labels_and_crops_291231b5" }, { "content": "def test_data_utils():\n \"\"\"Test utility functions in ultralytics/data/utils.py, including dataset stats and auto-splitting.\"\"\"\n from ultralytics.data.split import autosplit\n from ultralytics.data.utils import HUBDatasetStats\n from ultralytics.utils.downloads import zip_directory\n\n # from ultralytics.utils.files import WorkingDirectory\n # with WorkingDirectory(ROOT.parent / 'tests'):\n\n for task in TASKS:\n file = Path(TASK2DATA[task]).with_suffix(\".zip\") # i.e. 
coco8.zip\n download(f\"https://github.com/ultralytics/hub/raw/main/example_datasets/{file}\", unzip=False, dir=TMP)\n stats = HUBDatasetStats(TMP / file, task=task)\n stats.get_json(save=True)\n stats.process_images()\n\n autosplit(TMP / \"coco8\")\n zip_directory(TMP / \"coco8/images/val\") # zip", "chunk_type": "function", "name": "test_data_utils", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 331, "end_line": 348, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Test utility functions in ultralytics/data/utils.py, including dataset stats and auto-splitting.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not ONLINE, reason='environment is offline')" ], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", 
"ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_data_utils_6ef1ca22" }, { "content": "def test_data_converter():\n \"\"\"Test dataset conversion functions from COCO to YOLO format and class mappings.\"\"\"\n from ultralytics.data.converter import coco80_to_coco91_class, convert_coco\n\n file = \"instances_val2017.json\"\n download(f\"https://github.com/ultralytics/assets/releases/download/v0.0.0/{file}\", dir=TMP)\n convert_coco(labels_dir=TMP, save_dir=TMP / \"yolo_labels\", use_segments=True, use_keypoints=False, cls91to80=True)\n coco80_to_coco91_class()", "chunk_type": "function", "name": "test_data_converter", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 352, "end_line": 359, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Test dataset conversion functions from COCO to YOLO format and class mappings.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not ONLINE, reason='environment is offline')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", 
"ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_data_converter_4d34d536" }, { "content": "def test_data_annotator():\n \"\"\"Test automatic annotation of data using detection and segmentation models.\"\"\"\n from ultralytics.data.annotator import auto_annotate\n\n auto_annotate(\n ASSETS,\n det_model=WEIGHTS_DIR / \"yolo11n.pt\",\n sam_model=WEIGHTS_DIR / \"mobile_sam.pt\",\n output_dir=TMP / \"auto_annotate_labels\",\n )", "chunk_type": "function", "name": "test_data_annotator", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 362, "end_line": 371, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Test automatic annotation of data using detection and segmentation models.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", 
"ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_data_annotator_16eec09d" }, { "content": "def test_events():\n \"\"\"Test event sending functionality.\"\"\"\n from ultralytics.hub.utils import Events\n\n events = Events()\n events.enabled = True\n cfg = copy(DEFAULT_CFG) # does not require deepcopy\n cfg.mode = \"test\"\n events(cfg)", "chunk_type": "function", "name": "test_events", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 374, "end_line": 382, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Test event sending functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", 
"ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_events_cd418bf5" }, { "content": "def test_cfg_init():\n \"\"\"Test configuration initialization utilities from the 'ultralytics.cfg' module.\"\"\"\n from ultralytics.cfg import check_dict_alignment, copy_default_cfg, smart_value\n\n with contextlib.suppress(SyntaxError):\n check_dict_alignment({\"a\": 1}, {\"b\": 2})\n copy_default_cfg()\n (Path.cwd() / DEFAULT_CFG_PATH.name.replace(\".yaml\", \"_copy.yaml\")).unlink(missing_ok=False)\n [smart_value(x) for x in {\"none\", \"true\", \"false\"}]", "chunk_type": "function", "name": "test_cfg_init", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 385, "end_line": 393, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Test configuration initialization utilities from the 'ultralytics.cfg' module.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", 
"ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_cfg_init_d18c9271" }, { "content": "def test_utils_init():\n \"\"\"Test initialization utilities in the Ultralytics library.\"\"\"\n from ultralytics.utils import get_git_branch, get_git_origin_url, get_ubuntu_version, is_github_action_running\n\n get_ubuntu_version()\n is_github_action_running()\n get_git_origin_url()\n get_git_branch()", "chunk_type": "function", "name": "test_utils_init", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 396, "end_line": 403, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Test initialization utilities in the Ultralytics library.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", 
"ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_init_624db119" }, { "content": "def test_utils_checks():\n \"\"\"Test various utility checks for filenames, git status, requirements, image sizes, and versions.\"\"\"\n checks.check_yolov5u_filename(\"yolov5n.pt\")\n checks.git_describe(ROOT)\n checks.check_requirements() # check requirements.txt\n checks.check_imgsz([600, 600], max_dim=1)\n checks.check_imshow(warn=True)\n checks.check_version(\"ultralytics\", \"8.0.0\")\n checks.print_args()", "chunk_type": "function", "name": "test_utils_checks", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 406, "end_line": 414, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test various utility checks for filenames, git status, requirements, image sizes, and versions.", "parameters": [], "return_type": null, "decorators": [], 
"complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_checks_afb148ef" }, { "content": "def test_utils_benchmarks():\n \"\"\"Benchmark model performance using 'ProfileModels' from 'ultralytics.utils.benchmarks'.\"\"\"\n from ultralytics.utils.benchmarks import ProfileModels\n\n ProfileModels([\"yolo11n.yaml\"], imgsz=32, min_time=1, num_timed_runs=3, 
num_warmup_runs=1).run()", "chunk_type": "function", "name": "test_utils_benchmarks", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 418, "end_line": 422, "start_col": 0, "end_col": 100, "parent_name": null, "docstring": "Benchmark model performance using 'ProfileModels' from 'ultralytics.utils.benchmarks'.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(WINDOWS, reason='Windows profiling is extremely slow (cause unknown)')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", 
"ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_benchmarks_86a2349a" }, { "content": "def test_utils_torchutils():\n \"\"\"Test Torch utility functions including profiling and FLOP calculations.\"\"\"\n from ultralytics.nn.modules.conv import Conv\n from ultralytics.utils.torch_utils import get_flops_with_torch_profiler, profile_ops, time_sync\n\n x = torch.randn(1, 64, 20, 20)\n m = Conv(64, 64, k=1, s=2)\n\n profile_ops(x, [m], n=3)\n get_flops_with_torch_profiler(m)\n time_sync()", "chunk_type": "function", "name": "test_utils_torchutils", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 425, "end_line": 435, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Test Torch utility functions including profiling and FLOP calculations.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", 
"ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_torchutils_ab38ed71" }, { "content": "def test_utils_ops():\n \"\"\"Test utility operations for coordinate transformations and normalizations.\"\"\"\n from ultralytics.utils.ops import (\n ltwh2xywh,\n ltwh2xyxy,\n make_divisible,\n xywh2ltwh,\n xywh2xyxy,\n xywhn2xyxy,\n xywhr2xyxyxyxy,\n xyxy2ltwh,\n xyxy2xywh,\n xyxy2xywhn,\n xyxyxyxy2xywhr,\n )\n\n make_divisible(17, torch.tensor([8]))\n\n boxes = torch.rand(10, 4) # xywh\n torch.allclose(boxes, xyxy2xywh(xywh2xyxy(boxes)))\n torch.allclose(boxes, xyxy2xywhn(xywhn2xyxy(boxes)))\n torch.allclose(boxes, ltwh2xywh(xywh2ltwh(boxes)))\n torch.allclose(boxes, xyxy2ltwh(ltwh2xyxy(boxes)))\n\n boxes = torch.rand(10, 5) # xywhr for OBB\n boxes[:, 4] = torch.randn(10) * 30\n torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3)", "chunk_type": "function", "name": "test_utils_ops", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 438, "end_line": 464, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": "Test utility operations for coordinate transformations and normalizations.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", 
"ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_ops_fff86537" }, { "content": "def test_utils_files():\n \"\"\"Test file handling utilities including file age, date, and paths with spaces.\"\"\"\n from ultralytics.utils.files import file_age, file_date, get_latest_run, spaces_in_path\n\n file_age(SOURCE)\n file_date(SOURCE)\n get_latest_run(ROOT / \"runs\")\n\n path = TMP / \"path/with spaces\"\n path.mkdir(parents=True, exist_ok=True)\n with spaces_in_path(path) as new_path:\n print(new_path)", "chunk_type": "function", "name": "test_utils_files", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 467, "end_line": 478, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Test file handling utilities including file age, date, and paths with spaces.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", 
"ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_files_9b2b0d08" }, { "content": "def test_utils_patches_torch_save():\n \"\"\"Test torch_save backoff when _torch_save raises RuntimeError.\"\"\"\n from unittest.mock import MagicMock, patch\n\n from ultralytics.utils.patches import torch_save\n\n mock = MagicMock(side_effect=RuntimeError)\n\n with patch(\"ultralytics.utils.patches._torch_save\", new=mock):\n with pytest.raises(RuntimeError):\n torch_save(torch.zeros(1), TMP / \"test.pt\")\n\n assert mock.call_count == 4, \"torch_save was not attempted the expected number of times\"", "chunk_type": "function", "name": "test_utils_patches_torch_save", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 482, "end_line": 494, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": "Test torch_save backoff when _torch_save raises RuntimeError.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", 
"ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_utils_patches_torch_save_b3af24b8" }, { "content": "def test_nn_modules_conv():\n \"\"\"Test Convolutional Neural Network modules including CBAM, Conv2, and ConvTranspose.\"\"\"\n from ultralytics.nn.modules.conv import CBAM, Conv2, ConvTranspose, DWConvTranspose2d, Focus\n\n c1, c2 = 8, 16 # input and output channels\n x = torch.zeros(4, c1, 10, 10) # BCHW\n\n # Run all modules not otherwise covered in tests\n DWConvTranspose2d(c1, c2)(x)\n ConvTranspose(c1, c2)(x)\n Focus(c1, c2)(x)\n CBAM(c1)(x)\n\n # Fuse ops\n m = Conv2(c1, c2)\n m.fuse_convs()\n m(x)", "chunk_type": "function", "name": "test_nn_modules_conv", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 497, "end_line": 513, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Test Convolutional Neural Network modules 
including CBAM, Conv2, and ConvTranspose.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_nn_modules_conv_39dac117" }, { "content": "def test_nn_modules_block():\n \"\"\"Test various neural network block modules.\"\"\"\n from ultralytics.nn.modules.block import C1, C3TR, BottleneckCSP, C3Ghost, 
C3x\n\n c1, c2 = 8, 16 # input and output channels\n x = torch.zeros(4, c1, 10, 10) # BCHW\n\n # Run all modules not otherwise covered in tests\n C1(c1, c2)(x)\n C3x(c1, c2)(x)\n C3TR(c1, c2)(x)\n C3Ghost(c1, c2)(x)\n BottleneckCSP(c1, c2)(x)", "chunk_type": "function", "name": "test_nn_modules_block", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 516, "end_line": 528, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Test various neural network block modules.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", 
"ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_nn_modules_block_a339f61e" }, { "content": "def test_hub():\n \"\"\"Test Ultralytics HUB functionalities.\"\"\"\n from ultralytics.hub import export_fmts_hub, logout\n from ultralytics.hub.utils import smart_request\n\n export_fmts_hub()\n logout()\n smart_request(\"GET\", \"https://github.com\", progress=True)", "chunk_type": "function", "name": "test_hub", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 532, "end_line": 539, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": "Test Ultralytics HUB functionalities.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(not ONLINE, reason='environment is offline')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", 
"ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_hub_8fd5499b" }, { "content": "def image():\n \"\"\"Load and return an image from a predefined source.\"\"\"\n return cv2.imread(str(SOURCE))", "chunk_type": "function", "name": "image", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 543, "end_line": 545, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Load and return an image from a predefined source.", "parameters": [], "return_type": null, "decorators": [ "pytest.fixture" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", 
"ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_image_f73df408" }, { "content": "def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):\n \"\"\"Test classification transforms during training with various augmentations.\"\"\"\n from ultralytics.data.augment import classify_augmentations\n\n transform = classify_augmentations(\n size=224,\n mean=(0.5, 0.5, 0.5),\n std=(0.5, 0.5, 0.5),\n scale=(0.08, 1.0),\n ratio=(3.0 / 4.0, 4.0 / 3.0),\n hflip=0.5,\n vflip=0.5,\n auto_augment=auto_augment,\n hsv_h=0.015,\n hsv_s=0.4,\n hsv_v=0.4,\n force_color_jitter=force_color_jitter,\n erasing=erasing,\n )\n\n transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))\n\n assert transformed_image.shape == (3, 224, 224)\n assert torch.is_tensor(transformed_image)\n assert transformed_image.dtype == torch.float32", "chunk_type": "function", "name": "test_classify_transforms_train", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 557, "end_line": 581, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": "Test classification transforms during training with various augmentations.", "parameters": [ "image", "auto_augment", "erasing", "force_color_jitter" ], "return_type": null, "decorators": [ "pytest.mark.parametrize('auto_augment, erasing, force_color_jitter', [(None, 0.0, False), ('randaugment', 0.5, True), ('augmix', 0.2, False), ('autoaugment', 0.0, True)])" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", 
"ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_classify_transforms_train_3015fce6" }, { "content": "def test_model_tune():\n \"\"\"Tune YOLO model for performance improvement.\"\"\"\n YOLO(\"yolo11n-pose.pt\").tune(data=\"coco8-pose.yaml\", plots=False, imgsz=32, epochs=1, iterations=2, device=\"cpu\")\n YOLO(\"yolo11n-cls.pt\").tune(data=\"imagenet10\", plots=False, imgsz=32, epochs=1, iterations=2, device=\"cpu\")", "chunk_type": "function", "name": "test_model_tune", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 586, "end_line": 589, "start_col": 0, "end_col": 111, "parent_name": null, "docstring": "Tune YOLO model for performance improvement.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.slow", "pytest.mark.skipif(not ONLINE, reason='environment is offline')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", 
"ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_model_tune_bd3baae8" }, { "content": "def test_model_embeddings():\n \"\"\"Test YOLO model embeddings extraction functionality.\"\"\"\n model_detect = YOLO(MODEL)\n model_segment = YOLO(WEIGHTS_DIR / \"yolo11n-seg.pt\")\n\n for batch in [SOURCE], [SOURCE, SOURCE]: # test batch size 1 and 2\n assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)\n assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)", "chunk_type": "function", "name": "test_model_embeddings", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 592, "end_line": 599, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": "Test YOLO model embeddings extraction functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", 
"ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_model_embeddings_d2d78945" }, { "content": "def test_yolo_world():\n \"\"\"Test YOLO world models with CLIP support.\"\"\"\n model = YOLO(WEIGHTS_DIR / \"yolov8s-world.pt\") # no YOLO11n-world model yet\n model.set_classes([\"tree\", \"window\"])\n model(SOURCE, conf=0.01)\n\n model = YOLO(WEIGHTS_DIR / \"yolov8s-worldv2.pt\") # no YOLO11n-world model yet\n # Training from a pretrained model. 
Eval is included at the final stage of training.\n # Use dota8.yaml which has fewer categories to reduce the inference time of CLIP model\n model.train(\n data=\"dota8.yaml\",\n epochs=1,\n imgsz=32,\n cache=\"disk\",\n close_mosaic=1,\n )\n\n # test WorWorldTrainerFromScratch\n from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n\n model = YOLO(\"yolov8s-worldv2.yaml\") # no YOLO11n-world model yet\n model.train(\n data={\"train\": {\"yolo_data\": [\"dota8.yaml\"]}, \"val\": {\"yolo_data\": [\"dota8.yaml\"]}},\n epochs=1,\n imgsz=32,\n cache=\"disk\",\n close_mosaic=1,\n trainer=WorldTrainerFromScratch,\n )", "chunk_type": "function", "name": "test_yolo_world", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 607, "end_line": 635, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Test YOLO world models with CLIP support.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(checks.IS_PYTHON_3_12, reason='YOLOWorld with CLIP is not supported in Python 3.12')", "pytest.mark.skipif(checks.IS_PYTHON_3_8 and LINUX and ARM64, reason='YOLOWorld with CLIP is not supported in Python 3.8 and aarch64 Linux')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", 
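Alongside the test_yolo_world chunk, a short sketch of open-vocabulary prediction with a YOLO-World checkpoint; the weights name and prompt classes come from the test itself, while the image path is a placeholder.

from ultralytics import YOLO

model = YOLO("yolov8s-world.pt")        # world checkpoint used in the test
model.set_classes(["tree", "window"])   # restrict the open-vocabulary head to these prompts
results = model("bus.jpg", conf=0.01)   # "bus.jpg" is a placeholder image path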
"ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_yolo_world_d8a74ff9" }, { "content": "def test_yoloe():\n \"\"\"Test YOLOE models with MobileClip support.\"\"\"\n # Predict\n # text-prompts\n model = YOLO(WEIGHTS_DIR / \"yoloe-11s-seg.pt\")\n names = [\"person\", \"bus\"]\n model.set_classes(names, model.get_text_pe(names))\n model(SOURCE, conf=0.01)\n\n import numpy as np\n\n from ultralytics import YOLOE\n from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor\n\n # visual-prompts\n visuals = dict(\n bboxes=np.array(\n [[221.52, 405.8, 344.98, 857.54], [120, 425, 160, 445]],\n ),\n cls=np.array([0, 1]),\n )\n model.predict(\n SOURCE,\n visual_prompts=visuals,\n predictor=YOLOEVPSegPredictor,\n )\n\n # Val\n model = YOLOE(WEIGHTS_DIR / \"yoloe-11s-seg.pt\")\n # text prompts\n model.val(data=\"coco128-seg.yaml\", imgsz=32)\n # visual prompts\n model.val(data=\"coco128-seg.yaml\", load_vp=True, imgsz=32)\n\n # Train, fine-tune\n from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer\n\n model = YOLOE(\"yoloe-11s-seg.pt\")\n model.train(\n data=\"coco128-seg.yaml\",\n epochs=1,\n close_mosaic=1,\n trainer=YOLOEPESegTrainer,\n imgsz=32,\n )\n\n # prompt-free\n # predict\n model = YOLOE(WEIGHTS_DIR / \"yoloe-11s-seg-pf.pt\")\n model.predict(SOURCE)\n # val\n model = YOLOE(\"yoloe-11s-seg.pt\") # or select yoloe-m/l-seg.pt for different sizes\n model.val(data=\"coco128-seg.yaml\", imgsz=32)", "chunk_type": "function", "name": "test_yoloe", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 643, "end_line": 695, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": "Test YOLOE models with MobileClip support.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(checks.IS_PYTHON_3_12 or not TORCH_1_9, reason='YOLOE with CLIP is not supported in Python 3.12')", "pytest.mark.skipif(checks.IS_PYTHON_3_8 and LINUX and ARM64, reason='YOLOE with CLIP is not supported in Python 3.8 and aarch64 Linux')" ], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", 
"ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_yoloe_9c818796" }, { "content": "def test_yolov10():\n \"\"\"Test YOLOv10 model training, validation, and prediction functionality.\"\"\"\n model = YOLO(\"yolov10n.yaml\")\n # train/val/predict\n model.train(data=\"coco8.yaml\", epochs=1, imgsz=32, close_mosaic=1, cache=\"disk\")\n model.val(data=\"coco8.yaml\", imgsz=32)\n model.predict(imgsz=32, save_txt=True, save_crop=True, augment=True)\n model(SOURCE)", "chunk_type": "function", "name": "test_yolov10", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 698, "end_line": 705, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Test YOLOv10 model training, validation, and prediction functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", 
"tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_yolov10_3a749ba2" }, { "content": "def test_multichannel():\n \"\"\"Test YOLO model multi-channel training, validation, and prediction functionality.\"\"\"\n model = YOLO(\"yolo11n.pt\")\n model.train(data=\"coco8-multispectral.yaml\", epochs=1, imgsz=32, close_mosaic=1, cache=\"disk\")\n model.val(data=\"coco8-multispectral.yaml\")\n im = np.zeros((32, 32, 10), dtype=np.uint8)\n model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)\n model.export(format=\"onnx\")", "chunk_type": "function", "name": "test_multichannel", "file_path": 
"ultralytics\\tests\\test_python.py", "start_line": 708, "end_line": 715, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "Test YOLO model multi-channel training, validation, and prediction functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", "ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_multichannel_558d9fd1" }, { 
"content": "def test_grayscale(task: str, model: str, data: str) -> None:\n \"\"\"Test YOLO model grayscale training, validation, and prediction functionality.\"\"\"\n if task == \"classify\": # not support grayscale classification yet\n return\n grayscale_data = Path(TMP) / f\"{Path(data).stem}-grayscale.yaml\"\n data = check_det_dataset(data)\n data[\"channels\"] = 1 # add additional channels key for grayscale\n YAML.save(grayscale_data, data)\n # remove npy files in train/val splits if exists, might be created by previous tests\n for split in {\"train\", \"val\"}:\n for npy_file in (Path(data[\"path\"]) / data[split]).glob(\"*.npy\"):\n npy_file.unlink()\n\n model = YOLO(model)\n model.train(data=grayscale_data, epochs=1, imgsz=32, close_mosaic=1)\n model.val(data=grayscale_data)\n im = np.zeros((32, 32, 1), dtype=np.uint8)\n model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)\n export_model = model.export(format=\"onnx\")\n\n model = YOLO(export_model, task=task)\n model.predict(source=im, imgsz=32)", "chunk_type": "function", "name": "test_grayscale", "file_path": "ultralytics\\tests\\test_python.py", "start_line": 719, "end_line": 740, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Test YOLO model grayscale training, validation, and prediction functionality.", "parameters": [ "task: str", "model: str", "data: str" ], "return_type": "None", "decorators": [ "pytest.mark.parametrize('task,model,data', TASK_MODEL_DATA)" ], "complexity_score": 4, "dependencies": [ "contextlib", "csv", "urllib", "copy.copy", "pathlib.Path", "cv2", "numpy", "pytest", "torch", "PIL.Image", "tests.CFG", "tests.MODEL", "tests.MODELS", "tests.SOURCE", "tests.SOURCES_LIST", "tests.TASK_MODEL_DATA", "tests.TMP", "ultralytics.RTDETR", "ultralytics.YOLO", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASKS", "ultralytics.data.build.load_inference_source", "ultralytics.data.utils.check_det_dataset", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.ROOT", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.checks", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.is_github_action_running", "ultralytics.utils.downloads.download", "ultralytics.utils.torch_utils.TORCH_1_9", "ultralytics.nn.tasks.DetectionModel", "ultralytics.data.split.autosplit", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.utils.downloads.zip_directory", "ultralytics.data.converter.coco80_to_coco91_class", "ultralytics.data.converter.convert_coco", "ultralytics.data.annotator.auto_annotate", "ultralytics.hub.utils.Events", "ultralytics.cfg.check_dict_alignment", "ultralytics.cfg.copy_default_cfg", "ultralytics.cfg.smart_value", "ultralytics.utils.get_git_branch", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.get_ubuntu_version", "ultralytics.utils.is_github_action_running", "ultralytics.utils.benchmarks.ProfileModels", "ultralytics.nn.modules.conv.Conv", "ultralytics.utils.torch_utils.get_flops_with_torch_profiler", "ultralytics.utils.torch_utils.profile_ops", "ultralytics.utils.torch_utils.time_sync", "ultralytics.utils.ops.ltwh2xywh", "ultralytics.utils.ops.ltwh2xyxy", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.ops.xywh2ltwh", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xywhn2xyxy", 
"ultralytics.utils.ops.xywhr2xyxyxyxy", "ultralytics.utils.ops.xyxy2ltwh", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.ops.xyxy2xywhn", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.files.file_age", "ultralytics.utils.files.file_date", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.files.spaces_in_path", "unittest.mock.MagicMock", "unittest.mock.patch", "ultralytics.utils.patches.torch_save", "ultralytics.nn.modules.conv.CBAM", "ultralytics.nn.modules.conv.Conv2", "ultralytics.nn.modules.conv.ConvTranspose", "ultralytics.nn.modules.conv.DWConvTranspose2d", "ultralytics.nn.modules.conv.Focus", "ultralytics.nn.modules.block.C1", "ultralytics.nn.modules.block.C3TR", "ultralytics.nn.modules.block.BottleneckCSP", "ultralytics.nn.modules.block.C3Ghost", "ultralytics.nn.modules.block.C3x", "ultralytics.hub.export_fmts_hub", "ultralytics.hub.logout", "ultralytics.hub.utils.smart_request", "ultralytics.data.augment.classify_augmentations", "ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch", "numpy", "ultralytics.YOLOE", "ultralytics.models.yolo.yoloe.YOLOEVPSegPredictor", "ultralytics.models.yolo.yoloe.YOLOEPESegTrainer" ], "chunk_id": "function_test_grayscale_6661adc8" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_d40e341e" }, { "content": "from unittest.mock import patch", "chunk_type": "import", "name": "patch", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_patch_4a550038" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_8b786bcb" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_205915f9" }, { "content": "import pytest", "chunk_type": "import", "name": "pytest", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pytest_84d34f9a" }, { "content": "from tests import MODEL, TMP", "chunk_type": "import", "name": "MODEL, TMP", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MODEL, TMP_df63cad3" }, { "content": "from 
ultralytics import solutions", "chunk_type": "import", "name": "solutions", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_solutions_0d3687c8" }, { "content": "from ultralytics.utils import ASSETS_URL, IS_RASPBERRYPI, checks", "chunk_type": "import", "name": "ASSETS_URL, IS_RASPBERRYPI, checks", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS_URL, IS_RASPBERRYPI, checks_4e02e3ba" }, { "content": "from ultralytics.utils.downloads import safe_download", "chunk_type": "import", "name": "safe_download", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_safe_download_b9205279" }, { "content": "SHOW = False", "chunk_type": "variable", "name": "SHOW", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SHOW_b8770977" }, { "content": "DEMO_VIDEO = \"solutions_ci_demo.mp4\" # for all the solutions, except workout, object cropping and parking management", "chunk_type": "variable", "name": "DEMO_VIDEO", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEMO_VIDEO_8272d882" }, { "content": "CROP_VIDEO = \"decelera_landscape_min.mov\" # for object cropping solution", "chunk_type": "variable", "name": "CROP_VIDEO", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CROP_VIDEO_f6a93be5" }, { "content": "POSE_VIDEO = \"solution_ci_pose_demo.mp4\" # only for workouts monitoring solution", "chunk_type": "variable", "name": "POSE_VIDEO", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 22, "end_line": 22, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_POSE_VIDEO_f84e9c8f" }, { "content": "PARKING_VIDEO = \"solution_ci_parking_demo.mp4\" # only for parking management solution", "chunk_type": "variable", "name": "PARKING_VIDEO", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"variable_PARKING_VIDEO_747a50be" }, { "content": "PARKING_AREAS_JSON = \"solution_ci_parking_areas.json\" # only for parking management solution", "chunk_type": "variable", "name": "PARKING_AREAS_JSON", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PARKING_AREAS_JSON_2ae3bc2b" }, { "content": "PARKING_MODEL = \"solutions_ci_parking_model.pt\" # only for parking management solution", "chunk_type": "variable", "name": "PARKING_MODEL", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PARKING_MODEL_24bee7a4" }, { "content": "VERTICAL_VIDEO = \"solution_vertical_demo.mp4\" # only for vertical line counting", "chunk_type": "variable", "name": "VERTICAL_VIDEO", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 26, "end_line": 26, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_VERTICAL_VIDEO_3087c5e1" }, { "content": "REGION = [(10, 200), (540, 200), (540, 180), (10, 180)] # for object counting, speed estimation and queue management", "chunk_type": "variable", "name": "REGION", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_REGION_ab8b4d04" }, { "content": "HORIZONTAL_LINE = [(10, 200), (540, 200)] # for object counting", "chunk_type": "variable", "name": "HORIZONTAL_LINE", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 28, "end_line": 28, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_HORIZONTAL_LINE_fb9e2c4f" }, { "content": "VERTICAL_LINE = [(320, 0), (320, 400)] # for object counting", "chunk_type": "variable", "name": "VERTICAL_LINE", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 29, "end_line": 29, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_VERTICAL_LINE_15f34f0c" }, { "content": "SOLUTIONS = [\n (\n \"ObjectCounter\",\n solutions.ObjectCounter,\n False,\n DEMO_VIDEO,\n {\"region\": REGION, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"ObjectCounter\",\n solutions.ObjectCounter,\n False,\n DEMO_VIDEO,\n {\"region\": HORIZONTAL_LINE, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"ObjectCounterVertical\",\n solutions.ObjectCounter,\n False,\n DEMO_VIDEO,\n {\"region\": VERTICAL_LINE, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"ObjectCounterwithOBB\",\n solutions.ObjectCounter,\n False,\n DEMO_VIDEO,\n {\"region\": REGION, \"model\": \"yolo11n-obb.pt\", \"show\": SHOW},\n ),\n (\n \"Heatmap\",\n solutions.Heatmap,\n False,\n DEMO_VIDEO,\n {\"colormap\": 
cv2.COLORMAP_PARULA, \"model\": MODEL, \"show\": SHOW, \"region\": None},\n ),\n (\n \"HeatmapWithRegion\",\n solutions.Heatmap,\n False,\n DEMO_VIDEO,\n {\"colormap\": cv2.COLORMAP_PARULA, \"region\": REGION, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"SpeedEstimator\",\n solutions.SpeedEstimator,\n False,\n DEMO_VIDEO,\n {\"region\": REGION, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"QueueManager\",\n solutions.QueueManager,\n False,\n DEMO_VIDEO,\n {\"region\": REGION, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"LineAnalytics\",\n solutions.Analytics,\n True,\n DEMO_VIDEO,\n {\"analytics_type\": \"line\", \"model\": MODEL, \"show\": SHOW, \"figsize\": (6.4, 3.2)},\n ),\n (\n \"PieAnalytics\",\n solutions.Analytics,\n True,\n DEMO_VIDEO,\n {\"analytics_type\": \"pie\", \"model\": MODEL, \"show\": SHOW, \"figsize\": (6.4, 3.2)},\n ),\n (\n \"BarAnalytics\",\n solutions.Analytics,\n True,\n DEMO_VIDEO,\n {\"analytics_type\": \"bar\", \"model\": MODEL, \"show\": SHOW, \"figsize\": (6.4, 3.2)},\n ),\n (\n \"AreaAnalytics\",\n solutions.Analytics,\n True,\n DEMO_VIDEO,\n {\"analytics_type\": \"area\", \"model\": MODEL, \"show\": SHOW, \"figsize\": (6.4, 3.2)},\n ),\n (\"TrackZone\", solutions.TrackZone, False, DEMO_VIDEO, {\"region\": REGION, \"model\": MODEL, \"show\": SHOW}),\n (\n \"ObjectCropper\",\n solutions.ObjectCropper,\n False,\n CROP_VIDEO,\n {\"crop_dir\": str(TMP / \"cropped-detections\"), \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"ObjectBlurrer\",\n solutions.ObjectBlurrer,\n False,\n DEMO_VIDEO,\n {\"blur_ratio\": 0.02, \"model\": MODEL, \"show\": SHOW},\n ),\n (\n \"InstanceSegmentation\",\n solutions.InstanceSegmentation,\n False,\n DEMO_VIDEO,\n {\"model\": \"yolo11n-seg.pt\", \"show\": SHOW},\n ),\n (\"VisionEye\", solutions.VisionEye, False, DEMO_VIDEO, {\"model\": MODEL, \"show\": SHOW}),\n (\n \"RegionCounter\",\n solutions.RegionCounter,\n False,\n DEMO_VIDEO,\n {\"region\": REGION, \"model\": MODEL, \"show\": SHOW},\n ),\n (\"AIGym\", solutions.AIGym, False, POSE_VIDEO, {\"kpts\": [6, 8, 10], \"show\": SHOW}),\n (\n \"ParkingManager\",\n solutions.ParkingManagement,\n False,\n PARKING_VIDEO,\n {\"model\": str(TMP / PARKING_MODEL), \"show\": SHOW, \"json_file\": str(TMP / PARKING_AREAS_JSON)},\n ),\n (\n \"StreamlitInference\",\n solutions.Inference,\n False,\n None, # streamlit application doesn't require video file\n {}, # streamlit application doesn't accept arguments\n ),\n]", "chunk_type": "variable", "name": "SOLUTIONS", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 32, "end_line": 162, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SOLUTIONS_ae6454cb" }, { "content": "def process_video(solution, video_path: str, needs_frame_count: bool = False):\n \"\"\"Process video with solution, feeding frames and optional frame count to the solution instance.\"\"\"\n cap = cv2.VideoCapture(video_path)\n assert cap.isOpened(), f\"Error reading video file {video_path}\"\n\n frame_count = 0\n while cap.isOpened():\n success, im0 = cap.read()\n if not success:\n break\n frame_count += 1\n im_copy = im0.copy()\n args = [im_copy, frame_count] if needs_frame_count else [im_copy]\n _ = solution(*args)\n\n cap.release()", "chunk_type": "function", "name": "process_video", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 165, "end_line": 180, "start_col": 0, "end_col": 17, 
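The SOLUTIONS table and the process_video helper above both rely on the same per-frame calling convention; the sketch below shows it for ObjectCounter, using the HORIZONTAL_LINE region from the table, "yolo11n.pt" in place of the MODEL constant, and a placeholder video path.

import cv2
from ultralytics import solutions

counter = solutions.ObjectCounter(region=[(10, 200), (540, 200)], model="yolo11n.pt", show=False)
cap = cv2.VideoCapture("demo.mp4")  # "demo.mp4" is a placeholder video path
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    counter(frame)                  # each solution instance is called once per frame
cap.release()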
"parent_name": null, "docstring": "Process video with solution, feeding frames and optional frame count to the solution instance.", "parameters": [ "solution", "video_path: str", "needs_frame_count: bool" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_process_video_c93f494c" }, { "content": "def test_solution(name, solution_class, needs_frame_count, video, kwargs):\n \"\"\"Test individual Ultralytics solution with video processing and parameter validation.\"\"\"\n if video:\n if name != \"ObjectCounterVertical\":\n safe_download(url=f\"{ASSETS_URL}/{video}\", dir=TMP)\n else:\n safe_download(url=f\"{ASSETS_URL}/{VERTICAL_VIDEO}\", dir=TMP)\n if name == \"ParkingManager\":\n safe_download(url=f\"{ASSETS_URL}/{PARKING_AREAS_JSON}\", dir=TMP)\n safe_download(url=f\"{ASSETS_URL}/{PARKING_MODEL}\", dir=TMP)\n elif name == \"StreamlitInference\":\n if checks.check_imshow(): # do not merge with elif above\n solution_class(**kwargs).inference() # requires interactive GUI environment\n return\n\n video = VERTICAL_VIDEO if name == \"ObjectCounterVertical\" else video\n process_video(\n solution=solution_class(**kwargs),\n video_path=str(TMP / video),\n needs_frame_count=needs_frame_count,\n )", "chunk_type": "function", "name": "test_solution", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 185, "end_line": 205, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Test individual Ultralytics solution with video processing and parameter validation.", "parameters": [ "name", "solution_class", "needs_frame_count", "video", "kwargs" ], "return_type": null, "decorators": [ "pytest.mark.skipif(IS_RASPBERRYPI, reason='Disabled for testing due to --slow test errors after YOLOE PR.')", "pytest.mark.parametrize('name, solution_class, needs_frame_count, video, kwargs', SOLUTIONS)" ], "complexity_score": 6, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_solution_ee677e77" }, { "content": "def test_similarity_search():\n \"\"\"Test similarity search solution with sample images and text query.\"\"\"\n safe_download(f\"{ASSETS_URL}/4-imgs-similaritysearch.zip\", dir=TMP) # 4 dog images for testing in a zip file\n searcher = solutions.VisualAISearch(data=str(TMP / \"4-imgs-similaritysearch\"))\n _ = searcher(\"a dog sitting on a bench\") # Returns the results in format \"- img name | similarity score\"", "chunk_type": "function", "name": "test_similarity_search", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 210, "end_line": 214, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Test similarity search solution with sample images and text query.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(checks.IS_PYTHON_3_8, reason='Disabled due to unsupported CLIP dependencies.')", "pytest.mark.skipif(IS_RASPBERRYPI, reason='Disabled due to 
slow performance on Raspberry Pi.')" ], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_similarity_search_3e64ae55" }, { "content": "def test_left_click_selection():\n \"\"\"Test distance calculation left click selection functionality.\"\"\"\n dc = solutions.DistanceCalculation()\n dc.boxes, dc.track_ids = [[10, 10, 50, 50]], [1]\n dc.mouse_event_for_distance(cv2.EVENT_LBUTTONDOWN, 30, 30, None, None)\n assert 1 in dc.selected_boxes", "chunk_type": "function", "name": "test_left_click_selection", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 217, "end_line": 222, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Test distance calculation left click selection functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_left_click_selection_c0269619" }, { "content": "def test_right_click_reset():\n \"\"\"Test distance calculation right click reset functionality.\"\"\"\n dc = solutions.DistanceCalculation()\n dc.selected_boxes, dc.left_mouse_count = {1: [10, 10, 50, 50]}, 1\n dc.mouse_event_for_distance(cv2.EVENT_RBUTTONDOWN, 0, 0, None, None)\n assert dc.selected_boxes == {}\n assert dc.left_mouse_count == 0", "chunk_type": "function", "name": "test_right_click_reset", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 225, "end_line": 231, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Test distance calculation right click reset functionality.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_right_click_reset_8ad368f8" }, { "content": "def test_parking_json_none():\n \"\"\"Test that ParkingManagement handles missing JSON gracefully.\"\"\"\n im0 = np.zeros((640, 480, 3), dtype=np.uint8)\n try:\n parkingmanager = solutions.ParkingManagement(json_path=None)\n parkingmanager(im0)\n except ValueError:\n pytest.skip(\"Skipping test due to missing JSON.\")", "chunk_type": "function", "name": "test_parking_json_none", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 234, "end_line": 241, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": "Test that ParkingManagement handles missing JSON gracefully.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", 
"ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_parking_json_none_a4dde1bf" }, { "content": "def test_analytics_graph_not_supported():\n \"\"\"Test that unsupported analytics type raises ModuleNotFoundError.\"\"\"\n try:\n analytics = solutions.Analytics(analytics_type=\"test\") # 'test' is unsupported\n analytics.process(im0=np.zeros((640, 480, 3), dtype=np.uint8), frame_number=0)\n assert False, \"Expected ModuleNotFoundError for unsupported chart type\"\n except ModuleNotFoundError as e:\n assert \"test chart is not supported\" in str(e)", "chunk_type": "function", "name": "test_analytics_graph_not_supported", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 244, "end_line": 251, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": "Test that unsupported analytics type raises ModuleNotFoundError.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_analytics_graph_not_supported_055757a3" }, { "content": "def test_area_chart_padding():\n \"\"\"Test area chart graph update with dynamic class padding logic.\"\"\"\n analytics = solutions.Analytics(analytics_type=\"area\")\n analytics.update_graph(frame_number=1, count_dict={\"car\": 2}, plot=\"area\")\n plot_im = analytics.update_graph(frame_number=2, count_dict={\"car\": 3, \"person\": 1}, plot=\"area\")\n assert plot_im is not None", "chunk_type": "function", "name": "test_area_chart_padding", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 254, "end_line": 259, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Test area chart graph update with dynamic class padding logic.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_area_chart_padding_dbc56bc4" }, { "content": "def test_config_update_method_with_invalid_argument():\n \"\"\"Test that update() raises ValueError for invalid config keys.\"\"\"\n obj = solutions.config.SolutionConfig()\n try:\n obj.update(invalid_key=123)\n assert False, \"Expected ValueError for invalid update argument\"\n except ValueError as e:\n assert \"is not a valid solution argument\" in str(e)", "chunk_type": "function", "name": "test_config_update_method_with_invalid_argument", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 262, "end_line": 269, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "Test that update() raises ValueError for invalid config keys.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", 
"ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_config_update_method_with_invalid_argument_a78da71f" }, { "content": "def test_plot_with_no_masks():\n \"\"\"Test that instance segmentation handles cases with no masks.\"\"\"\n im0 = np.zeros((640, 480, 3), dtype=np.uint8)\n isegment = solutions.InstanceSegmentation(model=\"yolo11n-seg.pt\")\n results = isegment(im0)\n assert results.plot_im is not None", "chunk_type": "function", "name": "test_plot_with_no_masks", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 272, "end_line": 277, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Test that instance segmentation handles cases with no masks.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_plot_with_no_masks_faa962c6" }, { "content": "def test_streamlit_handle_video_upload_creates_file():\n \"\"\"Test Streamlit video upload logic saves file correctly.\"\"\"\n import io\n\n fake_file = io.BytesIO(b\"fake video content\")\n fake_file.read = fake_file.getvalue\n if fake_file is not None:\n g = io.BytesIO(fake_file.read())\n with open(\"ultralytics.mp4\", \"wb\") as out:\n out.write(g.read())\n output_path = \"ultralytics.mp4\"\n else:\n output_path = None\n assert output_path == \"ultralytics.mp4\"\n assert os.path.exists(\"ultralytics.mp4\")\n with open(\"ultralytics.mp4\", \"rb\") as f:\n assert f.read() == b\"fake video content\"\n os.remove(\"ultralytics.mp4\")", "chunk_type": "function", "name": "test_streamlit_handle_video_upload_creates_file", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 280, "end_line": 297, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Test Streamlit video upload logic saves file correctly.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_streamlit_handle_video_upload_creates_file_9422cf47" }, { "content": "def test_similarity_search_app_init():\n \"\"\"Test SearchApp initializes with required attributes.\"\"\"\n app = solutions.SearchApp(device=\"cpu\")\n assert hasattr(app, \"searcher\")\n assert hasattr(app, \"run\")", "chunk_type": "function", "name": "test_similarity_search_app_init", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 302, "end_line": 306, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Test SearchApp initializes with required attributes.", "parameters": [], "return_type": null, "decorators": [ "pytest.mark.skipif(checks.IS_PYTHON_3_8, reason='Disabled due to unsupported CLIP dependencies.')", "pytest.mark.skipif(IS_RASPBERRYPI, reason='Disabled due to slow performance on Raspberry Pi.')" ], 
"complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_similarity_search_app_init_2cd486cd" }, { "content": "def test_similarity_search_complete(tmp_path):\n \"\"\"Test VisualAISearch end-to-end with sample image and query.\"\"\"\n from PIL import Image\n\n image_dir = tmp_path / \"images\"\n os.makedirs(image_dir, exist_ok=True)\n for i in range(2):\n img = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))\n img.save(image_dir / f\"test_image_{i}.jpg\")\n searcher = solutions.VisualAISearch(data=str(image_dir))\n results = searcher(\"a red and white object\")\n assert results", "chunk_type": "function", "name": "test_similarity_search_complete", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 310, "end_line": 321, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Test VisualAISearch end-to-end with sample image and query.", "parameters": [ "tmp_path" ], "return_type": null, "decorators": [ "pytest.mark.skipif(IS_RASPBERRYPI, reason='Disabled due to slow performance on Raspberry Pi.')" ], "complexity_score": 2, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_similarity_search_complete_0f8b09d0" }, { "content": "def test_distance_calculation_process_method():\n \"\"\"Test DistanceCalculation.process() computes distance between selected boxes.\"\"\"\n from ultralytics.solutions.solutions import SolutionResults\n\n dc = solutions.DistanceCalculation()\n dc.boxes, dc.track_ids, dc.clss, dc.confs = (\n [[100, 100, 200, 200], [300, 300, 400, 400]],\n [1, 2],\n [0, 0],\n [0.9, 0.95],\n )\n dc.selected_boxes = {1: dc.boxes[0], 2: dc.boxes[1]}\n frame = np.zeros((480, 640, 3), dtype=np.uint8)\n with patch.object(dc, \"extract_tracks\"), patch.object(dc, \"display_output\"), patch(\"cv2.setMouseCallback\"):\n result = dc.process(frame)\n assert isinstance(result, SolutionResults)\n assert result.total_tracks == 2\n assert result.pixels_distance > 0", "chunk_type": "function", "name": "test_distance_calculation_process_method", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 324, "end_line": 341, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Test DistanceCalculation.process() computes distance between selected boxes.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_distance_calculation_process_method_6fc9630c" }, { "content": "def test_object_crop_with_show_True():\n \"\"\"Test ObjectCropper init with show=True to cover display warning.\"\"\"\n solutions.ObjectCropper(show=True)", 
"chunk_type": "function", "name": "test_object_crop_with_show_True", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 344, "end_line": 346, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Test ObjectCropper init with show=True to cover display warning.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_object_crop_with_show_True_6d292e3f" }, { "content": "def test_display_output_method():\n \"\"\"Test that display_output triggers imshow, waitKey, and destroyAllWindows when enabled.\"\"\"\n counter = solutions.ObjectCounter(show=True)\n counter.env_check = True\n frame = np.zeros((100, 100, 3), dtype=np.uint8)\n with patch(\"cv2.imshow\") as mock_imshow, patch(\"cv2.waitKey\", return_value=ord(\"q\")) as mock_wait, patch(\n \"cv2.destroyAllWindows\"\n ) as mock_destroy:\n counter.display_output(frame)\n mock_imshow.assert_called_once()\n mock_wait.assert_called_once()\n mock_destroy.assert_called_once()", "chunk_type": "function", "name": "test_display_output_method", "file_path": "ultralytics\\tests\\test_solutions.py", "start_line": 349, "end_line": 360, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Test that display_output triggers imshow, waitKey, and destroyAllWindows when enabled.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "unittest.mock.patch", "cv2", "numpy", "pytest", "tests.MODEL", "tests.TMP", "ultralytics.solutions", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.IS_RASPBERRYPI", "ultralytics.utils.checks", "ultralytics.utils.downloads.safe_download", "io", "PIL.Image", "ultralytics.solutions.solutions.SolutionResults" ], "chunk_id": "function_test_display_output_method_7e38be60" }, { "content": "from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS", "chunk_type": "import", "name": "TASK2DATA, TASK2MODEL, TASKS", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, TASK2MODEL, TASKS_5ae92667" }, { "content": "from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks", "chunk_type": "import", "name": "ASSETS, ROOT, WEIGHTS_DIR, checks", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, ROOT, WEIGHTS_DIR, checks_22cd9f19" }, { "content": "MODEL = WEIGHTS_DIR / \"path with spaces\" / \"yolo11n.pt\" # test spaces in path", "chunk_type": "variable", "name": "MODEL", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_MODEL_56f646ae" }, { "content": "CFG = 
\"yolo11n.yaml\"", "chunk_type": "variable", "name": "CFG", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CFG_e9a494da" }, { "content": "SOURCE = ASSETS / \"bus.jpg\"", "chunk_type": "variable", "name": "SOURCE", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SOURCE_a7d22237" }, { "content": "SOURCES_LIST = [ASSETS / \"bus.jpg\", ASSETS, ASSETS / \"*\", ASSETS / \"**/*.jpg\"]", "chunk_type": "variable", "name": "SOURCES_LIST", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SOURCES_LIST_6172e0fc" }, { "content": "TMP = (ROOT / \"../tests/tmp\").resolve() # temp directory for test files", "chunk_type": "variable", "name": "TMP", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TMP_2ef51fb9" }, { "content": "CUDA_IS_AVAILABLE = checks.cuda_is_available()", "chunk_type": "variable", "name": "CUDA_IS_AVAILABLE", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CUDA_IS_AVAILABLE_19cd9d72" }, { "content": "CUDA_DEVICE_COUNT = checks.cuda_device_count()", "chunk_type": "variable", "name": "CUDA_DEVICE_COUNT", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CUDA_DEVICE_COUNT_8b5e40a8" }, { "content": "TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]", "chunk_type": "variable", "name": "TASK_MODEL_DATA", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TASK_MODEL_DATA_bfe63f30" }, { "content": "MODELS = frozenset(list(TASK2MODEL.values()) + [\"yolo11n-grayscale.pt\"])", "chunk_type": "variable", "name": "MODELS", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_MODELS_2793f363" }, { "content": "__all__ = (\n \"MODEL\",\n \"CFG\",\n \"SOURCE\",\n \"SOURCES_LIST\",\n \"TMP\",\n \"CUDA_IS_AVAILABLE\",\n 
\"CUDA_DEVICE_COUNT\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\tests\\__init__.py", "start_line": 17, "end_line": 25, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___a65e806d" }, { "content": "__version__ = \"8.3.167\"", "chunk_type": "variable", "name": "__version__", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___version___0bcf3d74" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_0d195f4e" }, { "content": "from ultralytics.models import NAS, RTDETR, SAM, YOLO, YOLOE, FastSAM, YOLOWorld", "chunk_type": "import", "name": "NAS, RTDETR, SAM, YOLO, YOLOE, FastSAM, YOLOWorld", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NAS, RTDETR, SAM, YOLO, YOLOE, FastSAM, YOLOWorld_ee126aed" }, { "content": "from ultralytics.utils import ASSETS, SETTINGS", "chunk_type": "import", "name": "ASSETS, SETTINGS", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, SETTINGS_103545a0" }, { "content": "from ultralytics.utils.checks import check_yolo as checks", "chunk_type": "import", "name": "check_yolo", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_yolo_b1e127d8" }, { "content": "from ultralytics.utils.downloads import download", "chunk_type": "import", "name": "download", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_download_f4de2c78" }, { "content": "settings = SETTINGS", "chunk_type": "variable", "name": "settings", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_settings_131451f8" }, { "content": "__all__ = (\n \"__version__\",\n \"ASSETS\",\n \"YOLO\",\n \"YOLOWorld\",\n \"YOLOE\",\n \"NAS\",\n \"SAM\",\n \"FastSAM\",\n \"RTDETR\",\n \"checks\",\n \"download\",\n 
\"settings\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\__init__.py", "start_line": 17, "end_line": 30, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___970e534e" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_7d9c15aa" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_7bceb227" }, { "content": "from typing import List, Optional", "chunk_type": "import", "name": "List, Optional", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional_4adae2b6" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_f8936e07" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_0b478dbd" }, { "content": "import onnxruntime as ort", "chunk_type": "import", "name": "onnxruntime", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_onnxruntime_41efc822" }, { "content": "import requests", "chunk_type": "import", "name": "requests", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_requests_801ab569" }, { "content": "import yaml", "chunk_type": "import", "name": "yaml", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
"dependencies": null, "chunk_id": "import_yaml_8edccfcf" }, { "content": "def download_file(url: str, local_path: str) -> str:\n \"\"\"\n Download a file from a URL to a local path.\n\n Args:\n url (str): URL of the file to download.\n local_path (str): Local path where the file will be saved.\n \"\"\"\n # Check if the local path already exists\n if os.path.exists(local_path):\n print(f\"File already exists at {local_path}. Skipping download.\")\n return local_path\n # Download the file from the URL\n print(f\"Downloading {url} to {local_path}...\")\n response = requests.get(url)\n with open(local_path, \"wb\") as f:\n f.write(response.content)\n\n return local_path", "chunk_type": "function", "name": "download_file", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 14, "end_line": 32, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Download a file from a URL to a local path.\n\nArgs:\n url (str): URL of the file to download.\n local_path (str): Local path where the file will be saved.", "parameters": [ "url: str", "local_path: str" ], "return_type": "str", "decorators": [], "complexity_score": 2, "dependencies": [ "argparse", "os", "typing.List", "typing.Optional", "cv2", "numpy", "onnxruntime", "requests", "yaml" ], "chunk_id": "function_download_file_ccfdc984" }, { "content": "class RTDETR:\n \"\"\"\n RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.\n\n This class implements the RT-DETR model for object detection tasks, supporting ONNX model inference and\n visualization of detection results with bounding boxes and class labels.\n\n Attributes:\n model_path (str): Path to the ONNX model file.\n img_path (str): Path to the input image.\n conf_thres (float): Confidence threshold for filtering detections.\n iou_thres (float): IoU threshold for non-maximum suppression.\n session (ort.InferenceSession): ONNX runtime inference session.\n model_input (list): Model input metadata.\n input_width (int): Width dimension required by the model.\n input_height (int): Height dimension required by the model.\n classes (List[str]): List of class names from COCO dataset.\n color_palette (np.ndarray): Random color palette for visualization.\n img (np.ndarray): Loaded input image.\n img_height (int): Height of the input image.\n img_width (int): Width of the input image.\n\n Methods:\n draw_detections: Draw bounding boxes and labels on the input image.\n preprocess: Preprocess the input image for model inference.\n bbox_cxcywh_to_xyxy: Convert bounding boxes from center format to corner format.\n postprocess: Postprocess model output to extract and visualize detections.\n main: Execute the complete object detection pipeline.\n\n Examples:\n Initialize RT-DETR detector and run inference\n >>> detector = RTDETR(\"rtdetr-l.onnx\", \"image.jpg\", conf_thres=0.5)\n >>> output_image = detector.main()\n >>> cv2.imshow(\"Detections\", output_image)\n \"\"\"\n\n def __init__(\n self,\n model_path: str,\n img_path: str,\n conf_thres: float = 0.5,\n iou_thres: float = 0.5,\n class_names: Optional[str] = None,\n ):\n \"\"\"\n Initialize the RT-DETR object detection model.\n\n Args:\n model_path (str): Path to the ONNX model file.\n img_path (str): Path to the input image.\n conf_thres (float, optional): Confidence threshold for filtering detections.\n iou_thres (float, optional): IoU threshold for non-maximum suppression.\n class_names (Optional[str], optional): Path to a YAML file containing class names.\n 
If None, uses COCO dataset classes.\n \"\"\"\n self.model_path = model_path\n self.img_path = img_path\n self.conf_thres = conf_thres\n self.iou_thres = iou_thres\n self.classes = class_names\n\n # Set up the ONNX runtime session with CUDA and CPU execution providers\n self.session = ort.InferenceSession(model_path, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n\n self.model_input = self.session.get_inputs()\n self.input_width = self.model_input[0].shape[2]\n self.input_height = self.model_input[0].shape[3]\n\n if self.classes is None:\n # Load class names from the COCO dataset YAML file\n self.classes = download_file(\n \"https://raw.githubusercontent.com/ultralytics/\"\n \"ultralytics/refs/heads/main/ultralytics/cfg/datasets/coco8.yaml\",\n \"coco8.yaml\",\n )\n\n # Parse the YAML file to get class names\n with open(self.classes) as f:\n class_data = yaml.safe_load(f)\n self.classes = list(class_data[\"names\"].values())\n\n # Ensure the classes are a list\n if not isinstance(self.classes, list):\n raise ValueError(\"Classes should be a list of class names.\")\n\n # Generate a color palette for drawing bounding boxes\n self.color_palette: np.ndarray = np.random.uniform(0, 255, size=(len(self.classes), 3))\n\n def draw_detections(self, box: np.ndarray, score: float, class_id: int) -> None:\n \"\"\"Draw bounding box and label on the input image for a detected object.\"\"\"\n # Extract the coordinates of the bounding box\n x1, y1, x2, y2 = box\n\n # Retrieve the color for the class ID\n color = self.color_palette[class_id]\n\n # Draw the bounding box on the image\n cv2.rectangle(self.img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)\n\n # Create the label text with class name and score\n label = f\"{self.classes[class_id]}: {score:.2f}\"\n\n # Calculate the dimensions of the label text\n (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)\n\n # Calculate the position of the label text\n label_x = x1\n label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10\n\n # Draw a filled rectangle as the background for the label text\n cv2.rectangle(\n self.img,\n (int(label_x), int(label_y - label_height)),\n (int(label_x + label_width), int(label_y + label_height)),\n color,\n cv2.FILLED,\n )\n\n # Draw the label text on the image\n cv2.putText(\n self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA\n )\n\n def preprocess(self) -> np.ndarray:\n \"\"\"\n Preprocess the input image for model inference.\n\n Loads the image, converts color space from BGR to RGB, resizes to model input dimensions, and normalizes\n pixel values to [0, 1] range.\n\n Returns:\n (np.ndarray): Preprocessed image data with shape (1, 3, H, W) ready for inference.\n \"\"\"\n # Read the input image using OpenCV\n self.img = cv2.imread(self.img_path)\n\n # Get the height and width of the input image\n self.img_height, self.img_width = self.img.shape[:2]\n\n # Convert the image color space from BGR to RGB\n img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)\n\n # Resize the image to match the input shape\n img = cv2.resize(img, (self.input_width, self.input_height))\n\n # Normalize the image data by dividing it by 255.0\n image_data = np.array(img) / 255.0\n\n # Transpose the image to have the channel dimension as the first dimension\n image_data = np.transpose(image_data, (2, 0, 1)) # Channel first\n\n # Expand the dimensions of the image data to match the expected input shape\n image_data = np.expand_dims(image_data, 
axis=0).astype(np.float32)\n\n return image_data\n\n def bbox_cxcywh_to_xyxy(self, boxes: np.ndarray) -> np.ndarray:\n \"\"\"\n Convert bounding boxes from center format to corner format.\n\n Args:\n boxes (np.ndarray): Array of shape (N, 4) where each row represents a bounding box in\n (center_x, center_y, width, height) format.\n\n Returns:\n (np.ndarray): Array of shape (N, 4) with bounding boxes in (x_min, y_min, x_max, y_max) format.\n \"\"\"\n # Calculate half width and half height of the bounding boxes\n half_width = boxes[:, 2] / 2\n half_height = boxes[:, 3] / 2\n\n # Calculate the coordinates of the bounding boxes\n x_min = boxes[:, 0] - half_width\n y_min = boxes[:, 1] - half_height\n x_max = boxes[:, 0] + half_width\n y_max = boxes[:, 1] + half_height\n\n # Return the bounding boxes in (x_min, y_min, x_max, y_max) format\n return np.column_stack((x_min, y_min, x_max, y_max))\n\n def postprocess(self, model_output: List[np.ndarray]) -> np.ndarray:\n \"\"\"\n Postprocess model output to extract and visualize detections.\n\n Applies confidence thresholding, converts bounding box format, scales coordinates to original image\n dimensions, and draws detection annotations.\n\n Args:\n model_output (List[np.ndarray]): Output tensors from the model inference.\n\n Returns:\n (np.ndarray): Annotated image with detection bounding boxes and labels.\n \"\"\"\n # Squeeze the model output to remove unnecessary dimensions\n outputs = np.squeeze(model_output[0])\n\n # Extract bounding boxes and scores from the model output\n boxes = outputs[:, :4]\n scores = outputs[:, 4:]\n\n # Get the class labels and scores for each detection\n labels = np.argmax(scores, axis=1)\n scores = np.max(scores, axis=1)\n\n # Apply confidence threshold to filter out low-confidence detections\n mask = scores > self.conf_thres\n boxes, scores, labels = boxes[mask], scores[mask], labels[mask]\n\n # Convert bounding boxes to (x_min, y_min, x_max, y_max) format\n boxes = self.bbox_cxcywh_to_xyxy(boxes)\n\n # Scale bounding boxes to match the original image dimensions\n boxes[:, 0::2] *= self.img_width\n boxes[:, 1::2] *= self.img_height\n\n # Draw detections on the image\n for box, score, label in zip(boxes, scores, labels):\n self.draw_detections(box, score, label)\n\n return self.img\n\n def main(self) -> np.ndarray:\n \"\"\"\n Execute the complete object detection pipeline on the input image.\n\n Performs preprocessing, ONNX model inference, and postprocessing to generate annotated detection results.\n\n Returns:\n (np.ndarray): Output image with detection annotations including bounding boxes and class labels.\n \"\"\"\n # Preprocess the image for model input\n image_data = self.preprocess()\n\n # Run the model inference\n model_output = self.session.run(None, {self.model_input[0].name: image_data})\n\n # Process and return the model output\n return self.postprocess(model_output)", "chunk_type": "class", "name": "RTDETR", "file_path": "ultralytics\\examples\\RTDETR-ONNXRuntime-Python\\main.py", "start_line": 35, "end_line": 272, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.\n\nThis class implements the RT-DETR model for object detection tasks, supporting ONNX model inference and\nvisualization of detection results with bounding boxes and class labels.\n\nAttributes:\n model_path (str): Path to the ONNX model file.\n img_path (str): Path to the input image.\n conf_thres (float): Confidence threshold 
for filtering detections.\n iou_thres (float): IoU threshold for non-maximum suppression.\n session (ort.InferenceSession): ONNX runtime inference session.\n model_input (list): Model input metadata.\n input_width (int): Width dimension required by the model.\n input_height (int): Height dimension required by the model.\n classes (List[str]): List of class names from COCO dataset.\n color_palette (np.ndarray): Random color palette for visualization.\n img (np.ndarray): Loaded input image.\n img_height (int): Height of the input image.\n img_width (int): Width of the input image.\n\nMethods:\n draw_detections: Draw bounding boxes and labels on the input image.\n preprocess: Preprocess the input image for model inference.\n bbox_cxcywh_to_xyxy: Convert bounding boxes from center format to corner format.\n postprocess: Postprocess model output to extract and visualize detections.\n main: Execute the complete object detection pipeline.\n\nExamples:\n Initialize RT-DETR detector and run inference\n >>> detector = RTDETR(\"rtdetr-l.onnx\", \"image.jpg\", conf_thres=0.5)\n >>> output_image = detector.main()\n >>> cv2.imshow(\"Detections\", output_image)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", "os", "typing.List", "typing.Optional", "cv2", "numpy", "onnxruntime", "requests", "yaml" ], "chunk_id": "class_RTDETR_f8d63bb2" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_b7c6dcd3" }, { "content": "from typing import Tuple", "chunk_type": "import", "name": "Tuple", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Tuple_da6be15f" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_45ac09dd" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_26f931fd" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_67ec3860" }, { "content": "from 
ultralytics.utils.plotting import Annotator, colors", "chunk_type": "import", "name": "Annotator, colors", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Annotator, colors_cde10634" }, { "content": "enable_gpu = False # Set True if running with CUDA", "chunk_type": "variable", "name": "enable_gpu", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_enable_gpu_d52c5f12" }, { "content": "model_file = \"yolo11s.pt\" # Path to model file", "chunk_type": "variable", "name": "model_file", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_model_file_18c5d015" }, { "content": "show_fps = True # If True, shows current FPS in top-left corner", "chunk_type": "variable", "name": "show_fps", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_show_fps_67762cba" }, { "content": "show_conf = False # Display or hide the confidence score", "chunk_type": "variable", "name": "show_conf", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_show_conf_162d1632" }, { "content": "save_video = True # Set True to save output video", "chunk_type": "variable", "name": "save_video", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_save_video_f4030b1d" }, { "content": "video_output_path = \"interactive_tracker_output.avi\" # Output video file name", "chunk_type": "variable", "name": "video_output_path", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_video_output_path_8d05636c" }, { "content": "conf = 0.3 # Min confidence for object detection (lower = more detections, possibly more false positives)", "chunk_type": "variable", "name": "conf", "file_path": 
"ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_conf_81f29fc5" }, { "content": "iou = 0.3 # IoU threshold for NMS (higher = less overlap allowed)", "chunk_type": "variable", "name": "iou", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_iou_6d6f09bb" }, { "content": "max_det = 20 # Maximum objects per image (increase for crowded scenes)", "chunk_type": "variable", "name": "max_det", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 22, "end_line": 22, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_max_det_d1e72903" }, { "content": "tracker = \"bytetrack.yaml\" # Tracker config: 'bytetrack.yaml', 'botsort.yaml', etc.", "chunk_type": "variable", "name": "tracker", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_tracker_4ae9199a" }, { "content": "track_args = {\n \"persist\": True, # Keep frames history as a stream for continuous tracking\n \"verbose\": False, # Print debug info from tracker\n}", "chunk_type": "variable", "name": "track_args", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 25, "end_line": 28, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_track_args_133521d1" }, { "content": "window_name = \"Ultralytics YOLO Interactive Tracking\" # Output window name", "chunk_type": "variable", "name": "window_name", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 30, "end_line": 30, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_window_name_3c81afe4" }, { "content": "classes = model.names # Store model class names", "chunk_type": "variable", "name": "classes", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 41, "end_line": 41, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_classes_417b0e3b" }, { "content": "cap = cv2.VideoCapture(0) # Replace with video path if needed", "chunk_type": "variable", "name": "cap", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 43, "end_line": 43, "start_col": 
0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_cap_8944292b" }, { "content": "vw = None", "chunk_type": "variable", "name": "vw", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 46, "end_line": 46, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_vw_4dab833e" }, { "content": "selected_object_id = None", "chunk_type": "variable", "name": "selected_object_id", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 51, "end_line": 51, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_selected_object_id_f410221a" }, { "content": "selected_bbox = None", "chunk_type": "variable", "name": "selected_bbox", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 52, "end_line": 52, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_selected_bbox_ad990615" }, { "content": "selected_center = None", "chunk_type": "variable", "name": "selected_center", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 53, "end_line": 53, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_selected_center_27af1610" }, { "content": "def get_center(x1: int, y1: int, x2: int, y2: int) -> Tuple[int, int]:\n \"\"\"\n Calculate the center point of a bounding box.\n\n Args:\n x1 (int): Top-left X coordinate.\n y1 (int): Top-left Y coordinate.\n x2 (int): Bottom-right X coordinate.\n y2 (int): Bottom-right Y coordinate.\n\n Returns:\n center_x (int): X-coordinate of the center point.\n center_y (int): Y-coordinate of the center point.\n \"\"\"\n return (x1 + x2) // 2, (y1 + y2) // 2", "chunk_type": "function", "name": "get_center", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 56, "end_line": 70, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Calculate the center point of a bounding box.\n\nArgs:\n x1 (int): Top-left X coordinate.\n y1 (int): Top-left Y coordinate.\n x2 (int): Bottom-right X coordinate.\n y2 (int): Bottom-right Y coordinate.\n\nReturns:\n center_x (int): X-coordinate of the center point.\n center_y (int): Y-coordinate of the center point.", "parameters": [ "x1: int", "y1: int", "x2: int", "y2: int" ], "return_type": "Tuple[int, int]", "decorators": [], "complexity_score": 1, "dependencies": [ "time", "typing.Tuple", "cv2", "ultralytics.YOLO", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_get_center_9f437c0d" }, { "content": "def extend_line_from_edge(mid_x: int, mid_y: int, direction: str, img_shape: Tuple[int, int, int]) -> Tuple[int, int]:\n \"\"\"\n Calculate the endpoint to extend a line from the center 
toward an image edge.\n\n Args:\n mid_x (int): X-coordinate of the midpoint.\n mid_y (int): Y-coordinate of the midpoint.\n direction (str): Direction to extend ('left', 'right', 'up', 'down').\n img_shape (Tuple[int, int, int]): Image shape in (height, width, channels).\n\n Returns:\n end_x (int): X-coordinate of the endpoint.\n end_y (int): Y-coordinate of the endpoint.\n \"\"\"\n h, w = img_shape[:2]\n if direction == \"left\":\n return 0, mid_y\n if direction == \"right\":\n return w - 1, mid_y\n if direction == \"up\":\n return mid_x, 0\n if direction == \"down\":\n return mid_x, h - 1\n return mid_x, mid_y", "chunk_type": "function", "name": "extend_line_from_edge", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 73, "end_line": 96, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Calculate the endpoint to extend a line from the center toward an image edge.\n\nArgs:\n mid_x (int): X-coordinate of the midpoint.\n mid_y (int): Y-coordinate of the midpoint.\n direction (str): Direction to extend ('left', 'right', 'up', 'down').\n img_shape (Tuple[int, int, int]): Image shape in (height, width, channels).\n\nReturns:\n end_x (int): X-coordinate of the endpoint.\n end_y (int): Y-coordinate of the endpoint.", "parameters": [ "mid_x: int", "mid_y: int", "direction: str", "img_shape: Tuple[int, int, int]" ], "return_type": "Tuple[int, int]", "decorators": [], "complexity_score": 5, "dependencies": [ "time", "typing.Tuple", "cv2", "ultralytics.YOLO", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_extend_line_from_edge_e7b204d2" }, { "content": "def draw_tracking_scope(im, bbox: tuple, color: tuple) -> None:\n \"\"\"\n Draw tracking scope lines extending from the bounding box to image edges.\n\n Args:\n im (np.ndarray): Image array to draw on.\n bbox (tuple): Bounding box coordinates (x1, y1, x2, y2).\n color (tuple): Color in BGR format for drawing.\n \"\"\"\n x1, y1, x2, y2 = bbox\n mid_top = ((x1 + x2) // 2, y1)\n mid_bottom = ((x1 + x2) // 2, y2)\n mid_left = (x1, (y1 + y2) // 2)\n mid_right = (x2, (y1 + y2) // 2)\n cv2.line(im, mid_top, extend_line_from_edge(*mid_top, \"up\", im.shape), color, 2)\n cv2.line(im, mid_bottom, extend_line_from_edge(*mid_bottom, \"down\", im.shape), color, 2)\n cv2.line(im, mid_left, extend_line_from_edge(*mid_left, \"left\", im.shape), color, 2)\n cv2.line(im, mid_right, extend_line_from_edge(*mid_right, \"right\", im.shape), color, 2)", "chunk_type": "function", "name": "draw_tracking_scope", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 99, "end_line": 116, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": "Draw tracking scope lines extending from the bounding box to image edges.\n\nArgs:\n im (np.ndarray): Image array to draw on.\n bbox (tuple): Bounding box coordinates (x1, y1, x2, y2).\n color (tuple): Color in BGR format for drawing.", "parameters": [ "im", "bbox: tuple", "color: tuple" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "time", "typing.Tuple", "cv2", "ultralytics.YOLO", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_draw_tracking_scope_ce5fb387" }, { "content": "def click_event(event: int, x: int, y: int, flags: int, param) -> None:\n \"\"\"\n Handle mouse click events to select an object for 
focused tracking.\n\n Args:\n event (int): OpenCV mouse event type.\n x (int): X-coordinate of the mouse event.\n y (int): Y-coordinate of the mouse event.\n flags (int): Any relevant flags passed by OpenCV.\n param (Any): Additional parameters (not used).\n \"\"\"\n global selected_object_id\n if event == cv2.EVENT_LBUTTONDOWN:\n detections = results[0].boxes.data if results[0].boxes is not None else []\n if detections is not None:\n min_area = float(\"inf\")\n best_match = None\n for track in detections:\n track = track.tolist()\n if len(track) >= 6:\n x1, y1, x2, y2 = map(int, track[:4])\n if x1 <= x <= x2 and y1 <= y <= y2:\n area = (x2 - x1) * (y2 - y1)\n if area < min_area:\n class_id = int(track[-1])\n track_id = int(track[4]) if len(track) == 7 else -1\n min_area = area\n best_match = (track_id, model.names[class_id])\n if best_match:\n selected_object_id, label = best_match\n print(f\"🔵 TRACKING STARTED: {label} (ID {selected_object_id})\")", "chunk_type": "function", "name": "click_event", "file_path": "ultralytics\\examples\\YOLO-Interactive-Tracking-UI\\interactive_tracker.py", "start_line": 119, "end_line": 149, "start_col": 0, "end_col": 82, "parent_name": null, "docstring": "Handle mouse click events to select an object for focused tracking.\n\nArgs:\n event (int): OpenCV mouse event type.\n x (int): X-coordinate of the mouse event.\n y (int): Y-coordinate of the mouse event.\n flags (int): Any relevant flags passed by OpenCV.\n param (Any): Additional parameters (not used).", "parameters": [ "event: int", "x: int", "y: int", "flags: int", "param" ], "return_type": "None", "decorators": [], "complexity_score": 8, "dependencies": [ "time", "typing.Tuple", "cv2", "ultralytics.YOLO", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_click_event_5600e9c3" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_f1b8345f" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_21463260" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_7bdf3b38" }, { "content": "from typing import List, Optional, Tuple", "chunk_type": "import", "name": "List, Optional, Tuple", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": 
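A sketch of how a handler like click_event is typically wired into an OpenCV window; the window name and the tracking loop here are assumptions rather than lines taken from the script, and model, cap and click_event are assumed to be in scope:

import cv2

WINDOW = "YOLO Interactive Tracking"       # hypothetical window name
cv2.namedWindow(WINDOW)
cv2.setMouseCallback(WINDOW, click_event)  # a left click selects the object to track

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = model.track(frame, persist=True)  # click_event reads the module-level `results`
    cv2.imshow(WINDOW, results[0].plot())
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

cap.release()
cv2.destroyAllWindows()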
null, "chunk_id": "import_List, Optional, Tuple_6425d028" }, { "content": "from urllib.parse import urlparse", "chunk_type": "import", "name": "urlparse", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_urlparse_a0202dfa" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_a68975f4" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_b05149fd" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_39effd4a" }, { "content": "from transformers import AutoModel, AutoProcessor", "chunk_type": "import", "name": "AutoModel, AutoProcessor", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AutoModel, AutoProcessor_0716673b" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_7e654a56" }, { "content": "from ultralytics.data.loaders import get_best_youtube_url", "chunk_type": "import", "name": "get_best_youtube_url", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_best_youtube_url_2f7f8ad9" }, { "content": "from ultralytics.utils.plotting import Annotator", "chunk_type": "import", "name": "Annotator", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_Annotator_10454841" }, { "content": "from ultralytics.utils.torch_utils import select_device", "chunk_type": "import", "name": "select_device", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_select_device_46f4f6a2" }, { "content": "class TorchVisionVideoClassifier:\n \"\"\"\n Video classifier using pretrained TorchVision models for action recognition.\n\n This class provides an interface for video classification using various pretrained models from TorchVision's\n video model collection, supporting models like S3D, R3D, Swin3D, and MViT architectures.\n\n Attributes:\n model (torch.nn.Module): The loaded TorchVision model for video classification.\n weights (torchvision.models.video.Weights): The weights used for the model.\n device (torch.device): The device on which the model is loaded.\n\n Methods:\n available_model_names: Returns a list of available model names.\n preprocess_crops_for_video_cls: Preprocesses crops for video classification.\n __call__: Performs inference on the given sequences.\n postprocess: Postprocesses the model's output.\n\n Examples:\n >>> classifier = TorchVisionVideoClassifier(\"s3d\", device=\"cpu\")\n >>> crops = [np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) for _ in range(8)]\n >>> tensor = classifier.preprocess_crops_for_video_cls(crops)\n >>> outputs = classifier(tensor)\n >>> labels, confidences = classifier.postprocess(outputs)\n\n References:\n https://pytorch.org/vision/stable/\n \"\"\"\n\n from torchvision.models.video import (\n MViT_V1_B_Weights,\n MViT_V2_S_Weights,\n R3D_18_Weights,\n S3D_Weights,\n Swin3D_B_Weights,\n Swin3D_T_Weights,\n mvit_v1_b,\n mvit_v2_s,\n r3d_18,\n s3d,\n swin3d_b,\n swin3d_t,\n )\n\n model_name_to_model_and_weights = {\n \"s3d\": (s3d, S3D_Weights.DEFAULT),\n \"r3d_18\": (r3d_18, R3D_18_Weights.DEFAULT),\n \"swin3d_t\": (swin3d_t, Swin3D_T_Weights.DEFAULT),\n \"swin3d_b\": (swin3d_b, Swin3D_B_Weights.DEFAULT),\n \"mvit_v1_b\": (mvit_v1_b, MViT_V1_B_Weights.DEFAULT),\n \"mvit_v2_s\": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT),\n }\n\n def __init__(self, model_name: str, device: str | torch.device = \"\"):\n \"\"\"\n Initialize the VideoClassifier with the specified model name and device.\n\n Args:\n model_name (str): The name of the model to use. Must be one of the available models.\n device (str | torch.device): The device to run the model on.\n \"\"\"\n if model_name not in self.model_name_to_model_and_weights:\n raise ValueError(f\"Invalid model name '{model_name}'. Available models: {self.available_model_names()}\")\n model, self.weights = self.model_name_to_model_and_weights[model_name]\n self.device = select_device(device)\n self.model = model(weights=self.weights).to(self.device).eval()\n\n @staticmethod\n def available_model_names() -> List[str]:\n \"\"\"\n Get the list of available model names.\n\n Returns:\n (List[str]): List of available model names that can be used with this classifier.\n \"\"\"\n return list(TorchVisionVideoClassifier.model_name_to_model_and_weights.keys())\n\n def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: List[int] = None) -> torch.Tensor:\n \"\"\"\n Preprocess a list of crops for video classification.\n\n Args:\n crops (List[np.ndarray]): List of crops to preprocess. 
Each crop should have dimensions (H, W, C).\n input_size (List[int], optional): The target input size for the model.\n\n Returns:\n (torch.Tensor): Preprocessed crops as a tensor with dimensions (1, T, C, H, W).\n \"\"\"\n if input_size is None:\n input_size = [224, 224]\n from torchvision.transforms import v2\n\n transform = v2.Compose(\n [\n v2.ToDtype(torch.float32, scale=True),\n v2.Resize(input_size, antialias=True),\n v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std),\n ]\n )\n\n processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops]\n return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device)\n\n def __call__(self, sequences: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform inference on the given sequences.\n\n Args:\n sequences (torch.Tensor): The input sequences for the model with dimensions (B, T, C, H, W) for batched\n video frames or (T, C, H, W) for single video frames.\n\n Returns:\n (torch.Tensor): The model's output logits.\n \"\"\"\n with torch.inference_mode():\n return self.model(sequences)\n\n def postprocess(self, outputs: torch.Tensor) -> Tuple[List[str], List[float]]:\n \"\"\"\n Postprocess the model's batch output.\n\n Args:\n outputs (torch.Tensor): The model's output logits.\n\n Returns:\n pred_labels (List[str]): The predicted labels.\n pred_confs (List[float]): The predicted confidences.\n \"\"\"\n pred_labels = []\n pred_confs = []\n for output in outputs:\n pred_class = output.argmax(0).item()\n pred_label = self.weights.meta[\"categories\"][pred_class]\n pred_labels.append(pred_label)\n pred_conf = output.softmax(0)[pred_class].item()\n pred_confs.append(pred_conf)\n\n return pred_labels, pred_confs", "chunk_type": "class", "name": "TorchVisionVideoClassifier", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 20, "end_line": 157, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Video classifier using pretrained TorchVision models for action recognition.\n\nThis class provides an interface for video classification using various pretrained models from TorchVision's\nvideo model collection, supporting models like S3D, R3D, Swin3D, and MViT architectures.\n\nAttributes:\n model (torch.nn.Module): The loaded TorchVision model for video classification.\n weights (torchvision.models.video.Weights): The weights used for the model.\n device (torch.device): The device on which the model is loaded.\n\nMethods:\n available_model_names: Returns a list of available model names.\n preprocess_crops_for_video_cls: Preprocesses crops for video classification.\n __call__: Performs inference on the given sequences.\n postprocess: Postprocesses the model's output.\n\nExamples:\n >>> classifier = TorchVisionVideoClassifier(\"s3d\", device=\"cpu\")\n >>> crops = [np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) for _ in range(8)]\n >>> tensor = classifier.preprocess_crops_for_video_cls(crops)\n >>> outputs = classifier(tensor)\n >>> labels, confidences = classifier.postprocess(outputs)\n\nReferences:\n https://pytorch.org/vision/stable/", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", "time", "collections.defaultdict", "typing.List", "typing.Optional", "typing.Tuple", "urllib.parse.urlparse", "cv2", "numpy", "torch", "transformers.AutoModel", "transformers.AutoProcessor", "ultralytics.YOLO", "ultralytics.data.loaders.get_best_youtube_url", 
"ultralytics.utils.plotting.Annotator", "ultralytics.utils.torch_utils.select_device", "torchvision.models.video.MViT_V1_B_Weights", "torchvision.models.video.MViT_V2_S_Weights", "torchvision.models.video.R3D_18_Weights", "torchvision.models.video.S3D_Weights", "torchvision.models.video.Swin3D_B_Weights", "torchvision.models.video.Swin3D_T_Weights", "torchvision.models.video.mvit_v1_b", "torchvision.models.video.mvit_v2_s", "torchvision.models.video.r3d_18", "torchvision.models.video.s3d", "torchvision.models.video.swin3d_b", "torchvision.models.video.swin3d_t", "torchvision.transforms.v2", "torchvision.transforms" ], "chunk_id": "class_TorchVisionVideoClassifier_7cd8c6de" }, { "content": "class HuggingFaceVideoClassifier:\n \"\"\"\n Zero-shot video classifier using Hugging Face transformer models.\n\n This class provides an interface for zero-shot video classification using Hugging Face models, supporting\n custom label sets and various transformer architectures for video understanding.\n\n Attributes:\n fp16 (bool): Whether to use FP16 for inference.\n labels (List[str]): List of labels for zero-shot classification.\n device (torch.device): The device on which the model is loaded.\n processor (transformers.AutoProcessor): The processor for the model.\n model (transformers.AutoModel): The loaded Hugging Face model.\n\n Methods:\n preprocess_crops_for_video_cls: Preprocesses crops for video classification.\n __call__: Performs inference on the given sequences.\n postprocess: Postprocesses the model's output.\n\n Examples:\n >>> labels = [\"walking\", \"running\", \"dancing\"]\n >>> classifier = HuggingFaceVideoClassifier(labels, device=\"cpu\")\n >>> crops = [np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) for _ in range(8)]\n >>> tensor = classifier.preprocess_crops_for_video_cls(crops)\n >>> outputs = classifier(tensor)\n >>> labels, confidences = classifier.postprocess(outputs)\n \"\"\"\n\n def __init__(\n self,\n labels: List[str],\n model_name: str = \"microsoft/xclip-base-patch16-zero-shot\",\n device: str | torch.device = \"\",\n fp16: bool = False,\n ):\n \"\"\"\n Initialize the HuggingFaceVideoClassifier with the specified model name.\n\n Args:\n labels (List[str]): List of labels for zero-shot classification.\n model_name (str): The name of the model to use.\n device (str | torch.device): The device to run the model on.\n fp16 (bool): Whether to use FP16 for inference.\n \"\"\"\n self.fp16 = fp16\n self.labels = labels\n self.device = select_device(device)\n self.processor = AutoProcessor.from_pretrained(model_name)\n model = AutoModel.from_pretrained(model_name).to(self.device)\n if fp16:\n model = model.half()\n self.model = model.eval()\n\n def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: List[int] = None) -> torch.Tensor:\n \"\"\"\n Preprocess a list of crops for video classification.\n\n Args:\n crops (List[np.ndarray]): List of crops to preprocess. 
Each crop should have dimensions (H, W, C).\n input_size (List[int], optional): The target input size for the model.\n\n Returns:\n (torch.Tensor): Preprocessed crops as a tensor with dimensions (1, T, C, H, W).\n \"\"\"\n if input_size is None:\n input_size = [224, 224]\n from torchvision import transforms\n\n transform = transforms.Compose(\n [\n transforms.Lambda(lambda x: x.float() / 255.0),\n transforms.Resize(input_size),\n transforms.Normalize(\n mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std\n ),\n ]\n )\n\n processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops] # (T, C, H, W)\n output = torch.stack(processed_crops).unsqueeze(0).to(self.device) # (1, T, C, H, W)\n if self.fp16:\n output = output.half()\n return output\n\n def __call__(self, sequences: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform inference on the given sequences.\n\n Args:\n sequences (torch.Tensor): The input sequences for the model. Batched video frames with shape (B, T, H, W, C).\n\n Returns:\n (torch.Tensor): The model's output logits.\n \"\"\"\n input_ids = self.processor(text=self.labels, return_tensors=\"pt\", padding=True)[\"input_ids\"].to(self.device)\n\n inputs = {\"pixel_values\": sequences, \"input_ids\": input_ids}\n\n with torch.inference_mode():\n outputs = self.model(**inputs)\n\n return outputs.logits_per_video\n\n def postprocess(self, outputs: torch.Tensor) -> Tuple[List[List[str]], List[List[float]]]:\n \"\"\"\n Postprocess the model's batch output.\n\n Args:\n outputs (torch.Tensor): The model's output logits.\n\n Returns:\n pred_labels (List[List[str]]): The predicted top2 labels for each sample.\n pred_confs (List[List[float]]): The predicted top2 confidences for each sample.\n \"\"\"\n pred_labels = []\n pred_confs = []\n\n with torch.no_grad():\n logits_per_video = outputs # Assuming outputs is already the logits tensor\n probs = logits_per_video.softmax(dim=-1) # Use softmax to convert logits to probabilities\n\n for prob in probs:\n top2_indices = prob.topk(2).indices.tolist()\n top2_labels = [self.labels[idx] for idx in top2_indices]\n top2_confs = prob[top2_indices].tolist()\n pred_labels.append(top2_labels)\n pred_confs.append(top2_confs)\n\n return pred_labels, pred_confs", "chunk_type": "class", "name": "HuggingFaceVideoClassifier", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 160, "end_line": 288, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Zero-shot video classifier using Hugging Face transformer models.\n\nThis class provides an interface for zero-shot video classification using Hugging Face models, supporting\ncustom label sets and various transformer architectures for video understanding.\n\nAttributes:\n fp16 (bool): Whether to use FP16 for inference.\n labels (List[str]): List of labels for zero-shot classification.\n device (torch.device): The device on which the model is loaded.\n processor (transformers.AutoProcessor): The processor for the model.\n model (transformers.AutoModel): The loaded Hugging Face model.\n\nMethods:\n preprocess_crops_for_video_cls: Preprocesses crops for video classification.\n __call__: Performs inference on the given sequences.\n postprocess: Postprocesses the model's output.\n\nExamples:\n >>> labels = [\"walking\", \"running\", \"dancing\"]\n >>> classifier = HuggingFaceVideoClassifier(labels, device=\"cpu\")\n >>> crops = [np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) for 
_ in range(8)]\n >>> tensor = classifier.preprocess_crops_for_video_cls(crops)\n >>> outputs = classifier(tensor)\n >>> labels, confidences = classifier.postprocess(outputs)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", "time", "collections.defaultdict", "typing.List", "typing.Optional", "typing.Tuple", "urllib.parse.urlparse", "cv2", "numpy", "torch", "transformers.AutoModel", "transformers.AutoProcessor", "ultralytics.YOLO", "ultralytics.data.loaders.get_best_youtube_url", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.torch_utils.select_device", "torchvision.models.video.MViT_V1_B_Weights", "torchvision.models.video.MViT_V2_S_Weights", "torchvision.models.video.R3D_18_Weights", "torchvision.models.video.S3D_Weights", "torchvision.models.video.Swin3D_B_Weights", "torchvision.models.video.Swin3D_T_Weights", "torchvision.models.video.mvit_v1_b", "torchvision.models.video.mvit_v2_s", "torchvision.models.video.r3d_18", "torchvision.models.video.s3d", "torchvision.models.video.swin3d_b", "torchvision.models.video.swin3d_t", "torchvision.transforms.v2", "torchvision.transforms" ], "chunk_id": "class_HuggingFaceVideoClassifier_0ba38bd4" }, { "content": "def crop_and_pad(frame: np.ndarray, box: List[float], margin_percent: int) -> np.ndarray:\n \"\"\"\n Crop box with margin and take square crop from frame.\n\n Args:\n frame (np.ndarray): The input frame to crop from.\n box (List[float]): The bounding box coordinates [x1, y1, x2, y2].\n margin_percent (int): The percentage of margin to add around the box.\n\n Returns:\n (np.ndarray): The cropped and resized square image.\n \"\"\"\n x1, y1, x2, y2 = map(int, box)\n w, h = x2 - x1, y2 - y1\n\n # Add margin\n margin_x, margin_y = int(w * margin_percent / 100), int(h * margin_percent / 100)\n x1, y1 = max(0, x1 - margin_x), max(0, y1 - margin_y)\n x2, y2 = min(frame.shape[1], x2 + margin_x), min(frame.shape[0], y2 + margin_y)\n\n # Take square crop from frame\n size = max(y2 - y1, x2 - x1)\n center_y, center_x = (y1 + y2) // 2, (x1 + x2) // 2\n half_size = size // 2\n square_crop = frame[\n max(0, center_y - half_size) : min(frame.shape[0], center_y + half_size),\n max(0, center_x - half_size) : min(frame.shape[1], center_x + half_size),\n ]\n\n return cv2.resize(square_crop, (224, 224), interpolation=cv2.INTER_LINEAR)", "chunk_type": "function", "name": "crop_and_pad", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 291, "end_line": 320, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": "Crop box with margin and take square crop from frame.\n\nArgs:\n frame (np.ndarray): The input frame to crop from.\n box (List[float]): The bounding box coordinates [x1, y1, x2, y2].\n margin_percent (int): The percentage of margin to add around the box.\n\nReturns:\n (np.ndarray): The cropped and resized square image.", "parameters": [ "frame: np.ndarray", "box: List[float]", "margin_percent: int" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 1, "dependencies": [ "argparse", "time", "collections.defaultdict", "typing.List", "typing.Optional", "typing.Tuple", "urllib.parse.urlparse", "cv2", "numpy", "torch", "transformers.AutoModel", "transformers.AutoProcessor", "ultralytics.YOLO", "ultralytics.data.loaders.get_best_youtube_url", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.torch_utils.select_device", "torchvision.models.video.MViT_V1_B_Weights", 
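Mirroring the Examples section of HuggingFaceVideoClassifier, a zero-shot sketch against a custom label set (dummy frames, CPU, no fp16; the default x-clip checkpoint is downloaded on first use):

import numpy as np

labels = ["walking", "running", "dancing"]
classifier = HuggingFaceVideoClassifier(labels, device="cpu")
crops = [np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) for _ in range(8)]
clip = classifier.preprocess_crops_for_video_cls(crops)
outputs = classifier(clip)
top2_labels, top2_confs = classifier.postprocess(outputs)  # top-2 labels/confidences per clip
print(list(zip(top2_labels[0], top2_confs[0])))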
"torchvision.models.video.MViT_V2_S_Weights", "torchvision.models.video.R3D_18_Weights", "torchvision.models.video.S3D_Weights", "torchvision.models.video.Swin3D_B_Weights", "torchvision.models.video.Swin3D_T_Weights", "torchvision.models.video.mvit_v1_b", "torchvision.models.video.mvit_v2_s", "torchvision.models.video.r3d_18", "torchvision.models.video.s3d", "torchvision.models.video.swin3d_b", "torchvision.models.video.swin3d_t", "torchvision.transforms.v2", "torchvision.transforms" ], "chunk_id": "function_crop_and_pad_918a3778" }, { "content": "def run(\n weights: str = \"yolo11n.pt\",\n device: str = \"\",\n source: str = \"https://www.youtube.com/watch?v=dQw4w9WgXcQ\",\n output_path: Optional[str] = None,\n crop_margin_percentage: int = 10,\n num_video_sequence_samples: int = 8,\n skip_frame: int = 2,\n video_cls_overlap_ratio: float = 0.25,\n fp16: bool = False,\n video_classifier_model: str = \"microsoft/xclip-base-patch32\",\n labels: List[str] = None,\n) -> None:\n \"\"\"\n Run action recognition on a video source using YOLO for object detection and a video classifier.\n\n Args:\n weights (str): Path to the YOLO model weights.\n device (str): Device to run the model on. Use 'cuda' for NVIDIA GPU, 'mps' for Apple Silicon, or 'cpu'.\n source (str): Path to mp4 video file or YouTube URL.\n output_path (str, optional): Path to save the output video.\n crop_margin_percentage (int): Percentage of margin to add around detected objects.\n num_video_sequence_samples (int): Number of video frames to use for classification.\n skip_frame (int): Number of frames to skip between detections.\n video_cls_overlap_ratio (float): Overlap ratio between video sequences.\n fp16 (bool): Whether to use half-precision floating point.\n video_classifier_model (str): Name or path of the video classifier model.\n labels (List[str], optional): List of labels for zero-shot classification.\n \"\"\"\n if labels is None:\n labels = [\n \"walking\",\n \"running\",\n \"brushing teeth\",\n \"looking into phone\",\n \"weight lifting\",\n \"cooking\",\n \"sitting\",\n ]\n # Initialize models and device\n device = select_device(device)\n yolo_model = YOLO(weights).to(device)\n if video_classifier_model in TorchVisionVideoClassifier.available_model_names():\n print(\"'fp16' is not supported for TorchVisionVideoClassifier. Setting fp16 to False.\")\n print(\n \"'labels' is not used for TorchVisionVideoClassifier. Ignoring the provided labels and using Kinetics-400 labels.\"\n )\n video_classifier = TorchVisionVideoClassifier(video_classifier_model, device=device)\n else:\n video_classifier = HuggingFaceVideoClassifier(\n labels, model_name=video_classifier_model, device=device, fp16=fp16\n )\n\n # Initialize video capture\n if source.startswith(\"http\") and urlparse(source).hostname in {\"www.youtube.com\", \"youtube.com\", \"youtu.be\"}:\n source = get_best_youtube_url(source)\n elif not source.endswith(\".mp4\"):\n raise ValueError(\"Invalid source. 
Supported sources are YouTube URLs and MP4 files.\")\n cap = cv2.VideoCapture(source)\n\n # Get video properties\n frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n fps = cap.get(cv2.CAP_PROP_FPS)\n\n # Initialize VideoWriter\n if output_path is not None:\n fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))\n\n # Initialize track history\n track_history = defaultdict(list)\n frame_counter = 0\n\n track_ids_to_infer = []\n crops_to_infer = []\n pred_labels = []\n pred_confs = []\n\n while cap.isOpened():\n success, frame = cap.read()\n if not success:\n break\n\n frame_counter += 1\n\n # Run YOLO tracking\n results = yolo_model.track(frame, persist=True, classes=[0]) # Track only person class\n\n if results[0].boxes.is_track:\n boxes = results[0].boxes.xyxy.cpu().numpy()\n track_ids = results[0].boxes.id.cpu().numpy()\n\n # Visualize prediction\n annotator = Annotator(frame, line_width=3, font_size=10, pil=False)\n\n if frame_counter % skip_frame == 0:\n crops_to_infer = []\n track_ids_to_infer = []\n\n for box, track_id in zip(boxes, track_ids):\n if frame_counter % skip_frame == 0:\n crop = crop_and_pad(frame, box, crop_margin_percentage)\n track_history[track_id].append(crop)\n\n if len(track_history[track_id]) > num_video_sequence_samples:\n track_history[track_id].pop(0)\n\n if len(track_history[track_id]) == num_video_sequence_samples and frame_counter % skip_frame == 0:\n start_time = time.time()\n crops = video_classifier.preprocess_crops_for_video_cls(track_history[track_id])\n end_time = time.time()\n preprocess_time = end_time - start_time\n print(f\"video cls preprocess time: {preprocess_time:.4f} seconds\")\n crops_to_infer.append(crops)\n track_ids_to_infer.append(track_id)\n\n if crops_to_infer and (\n not pred_labels\n or frame_counter % int(num_video_sequence_samples * skip_frame * (1 - video_cls_overlap_ratio)) == 0\n ):\n crops_batch = torch.cat(crops_to_infer, dim=0)\n\n start_inference_time = time.time()\n output_batch = video_classifier(crops_batch)\n end_inference_time = time.time()\n inference_time = end_inference_time - start_inference_time\n print(f\"video cls inference time: {inference_time:.4f} seconds\")\n\n pred_labels, pred_confs = video_classifier.postprocess(output_batch)\n\n if track_ids_to_infer and crops_to_infer:\n for box, track_id, pred_label, pred_conf in zip(boxes, track_ids_to_infer, pred_labels, pred_confs):\n top2_preds = sorted(zip(pred_label, pred_conf), key=lambda x: x[1], reverse=True)\n label_text = \" | \".join([f\"{label} ({conf:.2f})\" for label, conf in top2_preds])\n annotator.box_label(box, label_text, color=(0, 0, 255))\n\n # Write the annotated frame to the output video\n if output_path is not None:\n out.write(frame)\n\n # Display the annotated frame\n cv2.imshow(\"YOLOv8 Tracking with S3D Classification\", frame)\n\n if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n break\n\n cap.release()\n if output_path is not None:\n out.release()\n cv2.destroyAllWindows()", "chunk_type": "function", "name": "run", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 323, "end_line": 473, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Run action recognition on a video source using YOLO for object detection and a video classifier.\n\nArgs:\n weights (str): Path to the YOLO model weights.\n device (str): Device to run the model on. 
Use 'cuda' for NVIDIA GPU, 'mps' for Apple Silicon, or 'cpu'.\n source (str): Path to mp4 video file or YouTube URL.\n output_path (str, optional): Path to save the output video.\n crop_margin_percentage (int): Percentage of margin to add around detected objects.\n num_video_sequence_samples (int): Number of video frames to use for classification.\n skip_frame (int): Number of frames to skip between detections.\n video_cls_overlap_ratio (float): Overlap ratio between video sequences.\n fp16 (bool): Whether to use half-precision floating point.\n video_classifier_model (str): Name or path of the video classifier model.\n labels (List[str], optional): List of labels for zero-shot classification.", "parameters": [ "weights: str", "device: str", "source: str", "output_path: Optional[str]", "crop_margin_percentage: int", "num_video_sequence_samples: int", "skip_frame: int", "video_cls_overlap_ratio: float", "fp16: bool", "video_classifier_model: str", "labels: List[str]" ], "return_type": "None", "decorators": [], "complexity_score": 21, "dependencies": [ "argparse", "time", "collections.defaultdict", "typing.List", "typing.Optional", "typing.Tuple", "urllib.parse.urlparse", "cv2", "numpy", "torch", "transformers.AutoModel", "transformers.AutoProcessor", "ultralytics.YOLO", "ultralytics.data.loaders.get_best_youtube_url", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.torch_utils.select_device", "torchvision.models.video.MViT_V1_B_Weights", "torchvision.models.video.MViT_V2_S_Weights", "torchvision.models.video.R3D_18_Weights", "torchvision.models.video.S3D_Weights", "torchvision.models.video.Swin3D_B_Weights", "torchvision.models.video.Swin3D_T_Weights", "torchvision.models.video.mvit_v1_b", "torchvision.models.video.mvit_v2_s", "torchvision.models.video.r3d_18", "torchvision.models.video.s3d", "torchvision.models.video.swin3d_b", "torchvision.models.video.swin3d_t", "torchvision.transforms.v2", "torchvision.transforms" ], "chunk_id": "function_run_3f9c1df8" }, { "content": "def parse_opt() -> argparse.Namespace:\n \"\"\"Parse command line arguments for action recognition pipeline.\"\"\"\n parser = argparse.ArgumentParser()\n parser.add_argument(\"--weights\", type=str, default=\"yolo11n.pt\", help=\"ultralytics detector model path\")\n parser.add_argument(\"--device\", default=\"\", help='cuda device, i.e. 
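A sketch of driving the pipeline from Python rather than the CLI; the local MP4 path and label set are placeholders, while the keyword names match run()'s signature above:

run(
    weights="yolo11n.pt",
    device="cpu",
    source="people.mp4",          # hypothetical local MP4 (YouTube URLs also work)
    output_path="annotated.mp4",
    video_classifier_model="microsoft/xclip-base-patch32",
    labels=["walking", "running", "sitting"],
    fp16=False,
)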
0 or 0,1,2,3 or cpu/mps, \"\" for auto-detection')\n parser.add_argument(\n \"--source\",\n type=str,\n default=\"https://www.youtube.com/watch?v=dQw4w9WgXcQ\",\n help=\"video file path or youtube URL\",\n )\n parser.add_argument(\"--output-path\", type=str, default=\"output_video.mp4\", help=\"output video file path\")\n parser.add_argument(\n \"--crop-margin-percentage\", type=int, default=10, help=\"percentage of margin to add around detected objects\"\n )\n parser.add_argument(\n \"--num-video-sequence-samples\", type=int, default=8, help=\"number of video frames to use for classification\"\n )\n parser.add_argument(\"--skip-frame\", type=int, default=2, help=\"number of frames to skip between detections\")\n parser.add_argument(\n \"--video-cls-overlap-ratio\", type=float, default=0.25, help=\"overlap ratio between video sequences\"\n )\n parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use FP16 for inference\")\n parser.add_argument(\n \"--video-classifier-model\", type=str, default=\"microsoft/xclip-base-patch32\", help=\"video classifier model name\"\n )\n parser.add_argument(\n \"--labels\",\n nargs=\"+\",\n type=str,\n default=[\"dancing\", \"singing a song\"],\n help=\"labels for zero-shot video classification\",\n )\n return parser.parse_args()", "chunk_type": "function", "name": "parse_opt", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 476, "end_line": 509, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Parse command line arguments for action recognition pipeline.", "parameters": [], "return_type": "argparse.Namespace", "decorators": [], "complexity_score": 1, "dependencies": [ "argparse", "time", "collections.defaultdict", "typing.List", "typing.Optional", "typing.Tuple", "urllib.parse.urlparse", "cv2", "numpy", "torch", "transformers.AutoModel", "transformers.AutoProcessor", "ultralytics.YOLO", "ultralytics.data.loaders.get_best_youtube_url", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.torch_utils.select_device", "torchvision.models.video.MViT_V1_B_Weights", "torchvision.models.video.MViT_V2_S_Weights", "torchvision.models.video.R3D_18_Weights", "torchvision.models.video.S3D_Weights", "torchvision.models.video.Swin3D_B_Weights", "torchvision.models.video.Swin3D_T_Weights", "torchvision.models.video.mvit_v1_b", "torchvision.models.video.mvit_v2_s", "torchvision.models.video.r3d_18", "torchvision.models.video.s3d", "torchvision.models.video.swin3d_b", "torchvision.models.video.swin3d_t", "torchvision.transforms.v2", "torchvision.transforms" ], "chunk_id": "function_parse_opt_98682cf0" }, { "content": "def main(opt: argparse.Namespace) -> None:\n \"\"\"Run the action recognition pipeline with parsed command line arguments.\"\"\"\n run(**vars(opt))", "chunk_type": "function", "name": "main", "file_path": "ultralytics\\examples\\YOLOv8-Action-Recognition\\action_recognition.py", "start_line": 512, "end_line": 514, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Run the action recognition pipeline with parsed command line arguments.", "parameters": [ "opt: argparse.Namespace" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "argparse", "time", "collections.defaultdict", "typing.List", "typing.Optional", "typing.Tuple", "urllib.parse.urlparse", "cv2", "numpy", "torch", "transformers.AutoModel", "transformers.AutoProcessor", "ultralytics.YOLO", "ultralytics.data.loaders.get_best_youtube_url", 
"ultralytics.utils.plotting.Annotator", "ultralytics.utils.torch_utils.select_device", "torchvision.models.video.MViT_V1_B_Weights", "torchvision.models.video.MViT_V2_S_Weights", "torchvision.models.video.R3D_18_Weights", "torchvision.models.video.S3D_Weights", "torchvision.models.video.Swin3D_B_Weights", "torchvision.models.video.Swin3D_T_Weights", "torchvision.models.video.mvit_v1_b", "torchvision.models.video.mvit_v2_s", "torchvision.models.video.r3d_18", "torchvision.models.video.s3d", "torchvision.models.video.swin3d_b", "torchvision.models.video.swin3d_t", "torchvision.transforms.v2", "torchvision.transforms" ], "chunk_id": "function_main_38ac49f4" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_7d185a70" }, { "content": "from typing import List, Tuple", "chunk_type": "import", "name": "List, Tuple", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Tuple_b5ec66ab" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_dbd66edf" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_667e6aba" }, { "content": "import onnxruntime as ort", "chunk_type": "import", "name": "onnxruntime", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_onnxruntime_7c0af90b" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_b4e1bc9f" }, { "content": "from ultralytics.utils import ASSETS, YAML", "chunk_type": "import", "name": "ASSETS, YAML", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, YAML_edee6aeb" }, { "content": "from ultralytics.utils.checks import 
check_requirements, check_yaml", "chunk_type": "import", "name": "check_requirements, check_yaml", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements, check_yaml_2c2925f3" }, { "content": "class YOLOv8:\n \"\"\"\n YOLOv8 object detection model class for handling ONNX inference and visualization.\n\n This class provides functionality to load a YOLOv8 ONNX model, perform inference on images,\n and visualize the detection results with bounding boxes and labels.\n\n Attributes:\n onnx_model (str): Path to the ONNX model file.\n input_image (str): Path to the input image file.\n confidence_thres (float): Confidence threshold for filtering detections.\n iou_thres (float): IoU threshold for non-maximum suppression.\n classes (List[str]): List of class names from the COCO dataset.\n color_palette (np.ndarray): Random color palette for visualizing different classes.\n input_width (int): Width dimension of the model input.\n input_height (int): Height dimension of the model input.\n img (np.ndarray): The loaded input image.\n img_height (int): Height of the input image.\n img_width (int): Width of the input image.\n\n Methods:\n letterbox: Resize and reshape images while maintaining aspect ratio by adding padding.\n draw_detections: Draw bounding boxes and labels on the input image based on detected objects.\n preprocess: Preprocess the input image before performing inference.\n postprocess: Perform post-processing on the model's output to extract and visualize detections.\n main: Perform inference using an ONNX model and return the output image with drawn detections.\n\n Examples:\n Initialize YOLOv8 detector and run inference\n >>> detector = YOLOv8(\"yolov8n.onnx\", \"image.jpg\", 0.5, 0.5)\n >>> output_image = detector.main()\n \"\"\"\n\n def __init__(self, onnx_model: str, input_image: str, confidence_thres: float, iou_thres: float):\n \"\"\"\n Initialize an instance of the YOLOv8 class.\n\n Args:\n onnx_model (str): Path to the ONNX model.\n input_image (str): Path to the input image.\n confidence_thres (float): Confidence threshold for filtering detections.\n iou_thres (float): IoU threshold for non-maximum suppression.\n \"\"\"\n self.onnx_model = onnx_model\n self.input_image = input_image\n self.confidence_thres = confidence_thres\n self.iou_thres = iou_thres\n\n # Load the class names from the COCO dataset\n self.classes = YAML.load(check_yaml(\"coco8.yaml\"))[\"names\"]\n\n # Generate a color palette for the classes\n self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))\n\n def letterbox(self, img: np.ndarray, new_shape: Tuple[int, int] = (640, 640)) -> Tuple[np.ndarray, Tuple[int, int]]:\n \"\"\"\n Resize and reshape images while maintaining aspect ratio by adding padding.\n\n Args:\n img (np.ndarray): Input image to be resized.\n new_shape (Tuple[int, int]): Target shape (height, width) for the image.\n\n Returns:\n img (np.ndarray): Resized and padded image.\n pad (Tuple[int, int]): Padding values (top, left) applied to the image.\n \"\"\"\n shape = img.shape[:2] # current shape [height, width]\n\n # Scale ratio (new / old)\n r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])\n\n # Compute padding\n new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))\n dw, dh = (new_shape[1] - 
new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding\n\n if shape[::-1] != new_unpad: # resize\n img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)\n top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))\n left, right = int(round(dw - 0.1)), int(round(dw + 0.1))\n img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))\n\n return img, (top, left)\n\n def draw_detections(self, img: np.ndarray, box: List[float], score: float, class_id: int) -> None:\n \"\"\"Draw bounding boxes and labels on the input image based on the detected objects.\"\"\"\n # Extract the coordinates of the bounding box\n x1, y1, w, h = box\n\n # Retrieve the color for the class ID\n color = self.color_palette[class_id]\n\n # Draw the bounding box on the image\n cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)\n\n # Create the label text with class name and score\n label = f\"{self.classes[class_id]}: {score:.2f}\"\n\n # Calculate the dimensions of the label text\n (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)\n\n # Calculate the position of the label text\n label_x = x1\n label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10\n\n # Draw a filled rectangle as the background for the label text\n cv2.rectangle(\n img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED\n )\n\n # Draw the label text on the image\n cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n\n def preprocess(self) -> Tuple[np.ndarray, Tuple[int, int]]:\n \"\"\"\n Preprocess the input image before performing inference.\n\n This method reads the input image, converts its color space, applies letterboxing to maintain aspect ratio,\n normalizes pixel values, and prepares the image data for model input.\n\n Returns:\n image_data (np.ndarray): Preprocessed image data ready for inference with shape (1, 3, height, width).\n pad (Tuple[int, int]): Padding values (top, left) applied during letterboxing.\n \"\"\"\n # Read the input image using OpenCV\n self.img = cv2.imread(self.input_image)\n\n # Get the height and width of the input image\n self.img_height, self.img_width = self.img.shape[:2]\n\n # Convert the image color space from BGR to RGB\n img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)\n\n img, pad = self.letterbox(img, (self.input_width, self.input_height))\n\n # Normalize the image data by dividing it by 255.0\n image_data = np.array(img) / 255.0\n\n # Transpose the image to have the channel dimension as the first dimension\n image_data = np.transpose(image_data, (2, 0, 1)) # Channel first\n\n # Expand the dimensions of the image data to match the expected input shape\n image_data = np.expand_dims(image_data, axis=0).astype(np.float32)\n\n # Return the preprocessed image data\n return image_data, pad\n\n def postprocess(self, input_image: np.ndarray, output: List[np.ndarray], pad: Tuple[int, int]) -> np.ndarray:\n \"\"\"\n Perform post-processing on the model's output to extract and visualize detections.\n\n This method processes the raw model output to extract bounding boxes, scores, and class IDs.\n It applies non-maximum suppression to filter overlapping detections and draws the results on the input image.\n\n Args:\n input_image (np.ndarray): The input image.\n output (List[np.ndarray]): The output arrays from the model.\n pad (Tuple[int, int]): Padding values (top, left) used during letterboxing.\n\n 
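A sketch isolating the letterbox() helper just shown; the model and image paths are placeholders, and since __init__ only loads class names and thresholds the ONNX file is never opened here:

import numpy as np

det = YOLOv8("yolov8n.onnx", "image.jpg", confidence_thres=0.5, iou_thres=0.5)
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # 480x640 test image
padded, (top, left) = det.letterbox(img, new_shape=(640, 640))
print(padded.shape, (top, left))  # (640, 640, 3) with gray (114) padding added top and bottom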
Returns:\n (np.ndarray): The input image with detections drawn on it.\n \"\"\"\n # Transpose and squeeze the output to match the expected shape\n outputs = np.transpose(np.squeeze(output[0]))\n\n # Get the number of rows in the outputs array\n rows = outputs.shape[0]\n\n # Lists to store the bounding boxes, scores, and class IDs of the detections\n boxes = []\n scores = []\n class_ids = []\n\n # Calculate the scaling factors for the bounding box coordinates\n gain = min(self.input_height / self.img_height, self.input_width / self.img_width)\n outputs[:, 0] -= pad[1]\n outputs[:, 1] -= pad[0]\n\n # Iterate over each row in the outputs array\n for i in range(rows):\n # Extract the class scores from the current row\n classes_scores = outputs[i][4:]\n\n # Find the maximum score among the class scores\n max_score = np.amax(classes_scores)\n\n # If the maximum score is above the confidence threshold\n if max_score >= self.confidence_thres:\n # Get the class ID with the highest score\n class_id = np.argmax(classes_scores)\n\n # Extract the bounding box coordinates from the current row\n x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]\n\n # Calculate the scaled coordinates of the bounding box\n left = int((x - w / 2) / gain)\n top = int((y - h / 2) / gain)\n width = int(w / gain)\n height = int(h / gain)\n\n # Add the class ID, score, and box coordinates to the respective lists\n class_ids.append(class_id)\n scores.append(max_score)\n boxes.append([left, top, width, height])\n\n # Apply non-maximum suppression to filter out overlapping bounding boxes\n indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)\n\n # Iterate over the selected indices after non-maximum suppression\n for i in indices:\n # Get the box, score, and class ID corresponding to the index\n box = boxes[i]\n score = scores[i]\n class_id = class_ids[i]\n\n # Draw the detection on the input image\n self.draw_detections(input_image, box, score, class_id)\n\n # Return the modified input image\n return input_image\n\n def main(self) -> np.ndarray:\n \"\"\"\n Perform inference using an ONNX model and return the output image with drawn detections.\n\n Returns:\n (np.ndarray): The output image with drawn detections.\n \"\"\"\n # Create an inference session using the ONNX model and specify execution providers\n session = ort.InferenceSession(self.onnx_model, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n\n # Get the model inputs\n model_inputs = session.get_inputs()\n\n # Store the shape of the input for later use\n input_shape = model_inputs[0].shape\n self.input_width = input_shape[2]\n self.input_height = input_shape[3]\n\n # Preprocess the image data\n img_data, pad = self.preprocess()\n\n # Run inference using the preprocessed image data\n outputs = session.run(None, {model_inputs[0].name: img_data})\n\n # Perform post-processing on the outputs to obtain output image\n return self.postprocess(self.img, outputs, pad)", "chunk_type": "class", "name": "YOLOv8", "file_path": "ultralytics\\examples\\YOLOv8-ONNXRuntime\\main.py", "start_line": 15, "end_line": 260, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "YOLOv8 object detection model class for handling ONNX inference and visualization.\n\nThis class provides functionality to load a YOLOv8 ONNX model, perform inference on images,\nand visualize the detection results with bounding boxes and labels.\n\nAttributes:\n onnx_model (str): Path to the ONNX model file.\n input_image (str): Path to the 
input image file.\n confidence_thres (float): Confidence threshold for filtering detections.\n iou_thres (float): IoU threshold for non-maximum suppression.\n classes (List[str]): List of class names from the COCO dataset.\n color_palette (np.ndarray): Random color palette for visualizing different classes.\n input_width (int): Width dimension of the model input.\n input_height (int): Height dimension of the model input.\n img (np.ndarray): The loaded input image.\n img_height (int): Height of the input image.\n img_width (int): Width of the input image.\n\nMethods:\n letterbox: Resize and reshape images while maintaining aspect ratio by adding padding.\n draw_detections: Draw bounding boxes and labels on the input image based on detected objects.\n preprocess: Preprocess the input image before performing inference.\n postprocess: Perform post-processing on the model's output to extract and visualize detections.\n main: Perform inference using an ONNX model and return the output image with drawn detections.\n\nExamples:\n Initialize YOLOv8 detector and run inference\n >>> detector = YOLOv8(\"yolov8n.onnx\", \"image.jpg\", 0.5, 0.5)\n >>> output_image = detector.main()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", "typing.List", "typing.Tuple", "cv2", "numpy", "onnxruntime", "torch", "ultralytics.utils.ASSETS", "ultralytics.utils.YAML", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_yaml" ], "chunk_id": "class_YOLOv8_81edbfd6" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_c1d35bf5" }, { "content": "from typing import Any, Dict, List", "chunk_type": "import", "name": "Any, Dict, List", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List_0370a665" }, { "content": "import cv2.dnn", "chunk_type": "import", "name": "cv2.dnn", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2.dnn_d8085c4d" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_cb82ea54" }, { "content": "from ultralytics.utils import ASSETS, YAML", "chunk_type": "import", "name": "ASSETS, YAML", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
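Mirroring the Examples section of the YOLOv8 class: end-to-end ONNX Runtime inference on an image, then displaying the annotated result (the model and image paths are placeholders):

import cv2

detector = YOLOv8("yolov8n.onnx", "image.jpg", confidence_thres=0.5, iou_thres=0.5)
output_image = detector.main()  # preprocess -> ONNX Runtime session.run -> postprocess
cv2.imshow("YOLOv8 ONNXRuntime detections", output_image)
cv2.waitKey(0)
cv2.destroyAllWindows()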
"complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, YAML_bbda73ad" }, { "content": "from ultralytics.utils.checks import check_yaml", "chunk_type": "import", "name": "check_yaml", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_yaml_f9ffeb54" }, { "content": "CLASSES = YAML.load(check_yaml(\"coco8.yaml\"))[\"names\"]", "chunk_type": "variable", "name": "CLASSES", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CLASSES_b7dde858" }, { "content": "colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))", "chunk_type": "variable", "name": "colors", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_colors_317a7e72" }, { "content": "def draw_bounding_box(\n img: np.ndarray, class_id: int, confidence: float, x: int, y: int, x_plus_w: int, y_plus_h: int\n) -> None:\n \"\"\"\n Draw bounding boxes on the input image based on the provided arguments.\n\n Args:\n img (np.ndarray): The input image to draw the bounding box on.\n class_id (int): Class ID of the detected object.\n confidence (float): Confidence score of the detected object.\n x (int): X-coordinate of the top-left corner of the bounding box.\n y (int): Y-coordinate of the top-left corner of the bounding box.\n x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.\n y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.\n \"\"\"\n label = f\"{CLASSES[class_id]} ({confidence:.2f})\"\n color = colors[class_id]\n cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)\n cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)", "chunk_type": "function", "name": "draw_bounding_box", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 16, "end_line": 34, "start_col": 0, "end_col": 86, "parent_name": null, "docstring": "Draw bounding boxes on the input image based on the provided arguments.\n\nArgs:\n img (np.ndarray): The input image to draw the bounding box on.\n class_id (int): Class ID of the detected object.\n confidence (float): Confidence score of the detected object.\n x (int): X-coordinate of the top-left corner of the bounding box.\n y (int): Y-coordinate of the top-left corner of the bounding box.\n x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.\n y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.", "parameters": [ "img: np.ndarray", "class_id: int", "confidence: float", "x: int", "y: int", "x_plus_w: int", "y_plus_h: int" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "argparse", "typing.Any", "typing.Dict", "typing.List", "cv2.dnn", "numpy", "ultralytics.utils.ASSETS", "ultralytics.utils.YAML", "ultralytics.utils.checks.check_yaml" ], 
"chunk_id": "function_draw_bounding_box_5eae9392" }, { "content": "def main(onnx_model: str, input_image: str) -> List[Dict[str, Any]]:\n \"\"\"\n Load ONNX model, perform inference, draw bounding boxes, and display the output image.\n\n Args:\n onnx_model (str): Path to the ONNX model.\n input_image (str): Path to the input image.\n\n Returns:\n (List[Dict[str, Any]]): List of dictionaries containing detection information such as class_id, class_name,\n confidence, box coordinates, and scale factor.\n \"\"\"\n # Load the ONNX model\n model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model)\n\n # Read the input image\n original_image: np.ndarray = cv2.imread(input_image)\n [height, width, _] = original_image.shape\n\n # Prepare a square image for inference\n length = max((height, width))\n image = np.zeros((length, length, 3), np.uint8)\n image[0:height, 0:width] = original_image\n\n # Calculate scale factor\n scale = length / 640\n\n # Preprocess the image and prepare blob for model\n blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True)\n model.setInput(blob)\n\n # Perform inference\n outputs = model.forward()\n\n # Prepare output array\n outputs = np.array([cv2.transpose(outputs[0])])\n rows = outputs.shape[1]\n\n boxes = []\n scores = []\n class_ids = []\n\n # Iterate through output to collect bounding boxes, confidence scores, and class IDs\n for i in range(rows):\n classes_scores = outputs[0][i][4:]\n (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)\n if maxScore >= 0.25:\n box = [\n outputs[0][i][0] - (0.5 * outputs[0][i][2]), # x center - width/2 = left x\n outputs[0][i][1] - (0.5 * outputs[0][i][3]), # y center - height/2 = top y\n outputs[0][i][2], # width\n outputs[0][i][3], # height\n ]\n boxes.append(box)\n scores.append(maxScore)\n class_ids.append(maxClassIndex)\n\n # Apply NMS (Non-maximum suppression)\n result_boxes = cv2.dnn.NMSBoxes(boxes, scores, 0.25, 0.45, 0.5)\n\n detections = []\n\n # Iterate through NMS results to draw bounding boxes and labels\n for i in range(len(result_boxes)):\n index = result_boxes[i]\n box = boxes[index]\n detection = {\n \"class_id\": class_ids[index],\n \"class_name\": CLASSES[class_ids[index]],\n \"confidence\": scores[index],\n \"box\": box,\n \"scale\": scale,\n }\n detections.append(detection)\n draw_bounding_box(\n original_image,\n class_ids[index],\n scores[index],\n round(box[0] * scale),\n round(box[1] * scale),\n round((box[0] + box[2]) * scale),\n round((box[1] + box[3]) * scale),\n )\n\n # Display the image with bounding boxes\n cv2.imshow(\"image\", original_image)\n cv2.waitKey(0)\n cv2.destroyAllWindows()\n\n return detections", "chunk_type": "function", "name": "main", "file_path": "ultralytics\\examples\\YOLOv8-OpenCV-ONNX-Python\\main.py", "start_line": 37, "end_line": 126, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Load ONNX model, perform inference, draw bounding boxes, and display the output image.\n\nArgs:\n onnx_model (str): Path to the ONNX model.\n input_image (str): Path to the input image.\n\nReturns:\n (List[Dict[str, Any]]): List of dictionaries containing detection information such as class_id, class_name,\n confidence, box coordinates, and scale factor.", "parameters": [ "onnx_model: str", "input_image: str" ], "return_type": "List[Dict[str, Any]]", "decorators": [], "complexity_score": 4, "dependencies": [ "argparse", "typing.Any", "typing.Dict", "typing.List", "cv2.dnn", "numpy", "ultralytics.utils.ASSETS", 
"ultralytics.utils.YAML", "ultralytics.utils.checks.check_yaml" ], "chunk_id": "function_main_0d626e9e" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_e13875bf" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_c3466af0" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_74f4f86f" }, { "content": "from typing import Any, List", "chunk_type": "import", "name": "Any, List", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List_390698c0" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_8ccdbaf0" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_5660fb03" }, { "content": "from shapely.geometry import Polygon", "chunk_type": "import", "name": "Polygon", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Polygon_d129a6e5" }, { "content": "from shapely.geometry.point import Point", "chunk_type": "import", "name": "Point", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Point_e2391ede" }, { "content": "from ultralytics import YOLO", 
"chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_005391d5" }, { "content": "from ultralytics.utils.files import increment_path", "chunk_type": "import", "name": "increment_path", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_increment_path_c2597402" }, { "content": "from ultralytics.utils.plotting import Annotator, colors", "chunk_type": "import", "name": "Annotator, colors", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Annotator, colors_55085f62" }, { "content": "track_history = defaultdict(list)", "chunk_type": "variable", "name": "track_history", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_track_history_1022ba69" }, { "content": "current_region = None", "chunk_type": "variable", "name": "current_region", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_current_region_822de894" }, { "content": "counting_regions = [\n {\n \"name\": \"Ultralytics YOLO Polygon Region\",\n \"polygon\": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points\n \"counts\": 0,\n \"dragging\": False,\n \"region_color\": (255, 42, 4), # BGR Value\n \"text_color\": (255, 255, 255), # Region Text Color\n },\n {\n \"name\": \"Ultralytics YOLO Rectangle Region\",\n \"polygon\": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points\n \"counts\": 0,\n \"dragging\": False,\n \"region_color\": (37, 255, 225), # BGR Value\n \"text_color\": (0, 0, 0), # Region Text Color\n },\n]", "chunk_type": "variable", "name": "counting_regions", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 20, "end_line": 37, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_counting_regions_a20a497f" }, { "content": "def mouse_callback(event: int, x: int, y: int, flags: int, param: Any) -> None:\n \"\"\"\n Handle mouse events for region manipulation in the video frame.\n\n This function enables interactive region selection and dragging functionality for counting regions. 
It responds to\n mouse button down, move, and up events to allow users to select and reposition counting regions in real-time.\n\n Args:\n event (int): The mouse event type (e.g., cv2.EVENT_LBUTTONDOWN, cv2.EVENT_MOUSEMOVE).\n x (int): The x-coordinate of the mouse pointer.\n y (int): The y-coordinate of the mouse pointer.\n flags (int): Additional flags passed by OpenCV.\n param (Any): Additional parameters passed to the callback.\n\n Examples:\n Set up mouse callback for interactive region manipulation\n >>> cv2.setMouseCallback(\"window_name\", mouse_callback)\n \"\"\"\n global current_region\n\n # Mouse left button down event\n if event == cv2.EVENT_LBUTTONDOWN:\n for region in counting_regions:\n if region[\"polygon\"].contains(Point((x, y))):\n current_region = region\n current_region[\"dragging\"] = True\n current_region[\"offset_x\"] = x\n current_region[\"offset_y\"] = y\n\n # Mouse move event\n elif event == cv2.EVENT_MOUSEMOVE:\n if current_region is not None and current_region[\"dragging\"]:\n dx = x - current_region[\"offset_x\"]\n dy = y - current_region[\"offset_y\"]\n current_region[\"polygon\"] = Polygon(\n [(p[0] + dx, p[1] + dy) for p in current_region[\"polygon\"].exterior.coords]\n )\n current_region[\"offset_x\"] = x\n current_region[\"offset_y\"] = y\n\n # Mouse left button up event\n elif event == cv2.EVENT_LBUTTONUP:\n if current_region is not None and current_region[\"dragging\"]:\n current_region[\"dragging\"] = False", "chunk_type": "function", "name": "mouse_callback", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 40, "end_line": 83, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Handle mouse events for region manipulation in the video frame.\n\nThis function enables interactive region selection and dragging functionality for counting regions. It responds to\nmouse button down, move, and up events to allow users to select and reposition counting regions in real-time.\n\nArgs:\n event (int): The mouse event type (e.g., cv2.EVENT_LBUTTONDOWN, cv2.EVENT_MOUSEMOVE).\n x (int): The x-coordinate of the mouse pointer.\n y (int): The y-coordinate of the mouse pointer.\n flags (int): Additional flags passed by OpenCV.\n param (Any): Additional parameters passed to the callback.\n\nExamples:\n Set up mouse callback for interactive region manipulation\n >>> cv2.setMouseCallback(\"window_name\", mouse_callback)", "parameters": [ "event: int", "x: int", "y: int", "flags: int", "param: Any" ], "return_type": "None", "decorators": [], "complexity_score": 9, "dependencies": [ "argparse", "collections.defaultdict", "pathlib.Path", "typing.Any", "typing.List", "cv2", "numpy", "shapely.geometry.Polygon", "shapely.geometry.point.Point", "ultralytics.YOLO", "ultralytics.utils.files.increment_path", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_mouse_callback_94778a61" }, { "content": "def run(\n weights: str = \"yolo11n.pt\",\n source: str = None,\n device: str = \"cpu\",\n view_img: bool = False,\n save_img: bool = False,\n exist_ok: bool = False,\n classes: List[int] = None,\n line_thickness: int = 2,\n track_thickness: int = 2,\n region_thickness: int = 2,\n) -> None:\n \"\"\"\n Run object detection and counting within specified regions using YOLO and ByteTrack.\n\n This function performs real-time object detection, tracking, and counting within user-defined polygonal or\n rectangular regions. 
It supports interactive region manipulation, multiple counting areas, and both live viewing\n and video saving capabilities.\n\n Args:\n weights (str): Path to the YOLO model weights file.\n source (str): Path to the input video file.\n device (str): Processing device specification ('cpu', '0', '1', etc.).\n view_img (bool): Display results in a live window.\n save_img (bool): Save processed video to file.\n exist_ok (bool): Overwrite existing output files without incrementing.\n classes (List[int], optional): Specific class IDs to detect and track.\n line_thickness (int): Thickness of bounding box lines.\n track_thickness (int): Thickness of object tracking lines.\n region_thickness (int): Thickness of counting region boundaries.\n\n Examples:\n Run region counting with default settings\n >>> run(source=\"video.mp4\", view_img=True)\n\n Run with custom model and specific classes\n >>> run(weights=\"yolo11s.pt\", source=\"traffic.mp4\", classes=[0, 2, 3], device=\"0\")\n \"\"\"\n vid_frame_count = 0\n\n # Check source path\n if not Path(source).exists():\n raise FileNotFoundError(f\"Source path '{source}' does not exist.\")\n\n # Setup Model\n model = YOLO(f\"{weights}\")\n model.to(\"cuda\") if device == \"0\" else model.to(\"cpu\")\n\n # Extract classes names\n names = model.names\n\n # Video setup\n videocapture = cv2.VideoCapture(source)\n frame_width = int(videocapture.get(3))\n frame_height = int(videocapture.get(4))\n fps = int(videocapture.get(5))\n fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n\n # Output setup\n save_dir = increment_path(Path(\"ultralytics_rc_output\") / \"exp\", exist_ok)\n save_dir.mkdir(parents=True, exist_ok=True)\n video_writer = cv2.VideoWriter(str(save_dir / f\"{Path(source).stem}.avi\"), fourcc, fps, (frame_width, frame_height))\n\n # Iterate over video frames\n while videocapture.isOpened():\n success, frame = videocapture.read()\n if not success:\n break\n vid_frame_count += 1\n\n # Extract the results\n results = model.track(frame, persist=True, classes=classes)\n\n if results[0].boxes.is_track:\n boxes = results[0].boxes.xyxy.cpu()\n track_ids = results[0].boxes.id.int().cpu().tolist()\n clss = results[0].boxes.cls.cpu().tolist()\n\n annotator = Annotator(frame, line_width=line_thickness, example=str(names))\n\n for box, track_id, cls in zip(boxes, track_ids, clss):\n annotator.box_label(box, str(names[cls]), color=colors(cls, True))\n bbox_center = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 # Bbox center\n\n track = track_history[track_id] # Tracking Lines plot\n track.append((float(bbox_center[0]), float(bbox_center[1])))\n if len(track) > 30:\n track.pop(0)\n points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))\n cv2.polylines(frame, [points], isClosed=False, color=colors(cls, True), thickness=track_thickness)\n\n # Check if detection inside region\n for region in counting_regions:\n if region[\"polygon\"].contains(Point((bbox_center[0], bbox_center[1]))):\n region[\"counts\"] += 1\n\n # Draw regions (Polygons/Rectangles)\n for region in counting_regions:\n region_label = str(region[\"counts\"])\n region_color = region[\"region_color\"]\n region_text_color = region[\"text_color\"]\n\n polygon_coordinates = np.array(region[\"polygon\"].exterior.coords, dtype=np.int32)\n centroid_x, centroid_y = int(region[\"polygon\"].centroid.x), int(region[\"polygon\"].centroid.y)\n\n text_size, _ = cv2.getTextSize(\n region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness\n )\n text_x = centroid_x - text_size[0] // 2\n text_y = 
centroid_y + text_size[1] // 2\n cv2.rectangle(\n frame,\n (text_x - 5, text_y - text_size[1] - 5),\n (text_x + text_size[0] + 5, text_y + 5),\n region_color,\n -1,\n )\n cv2.putText(\n frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness\n )\n cv2.polylines(frame, [polygon_coordinates], isClosed=True, color=region_color, thickness=region_thickness)\n\n if view_img:\n if vid_frame_count == 1:\n cv2.namedWindow(\"Ultralytics YOLO Region Counter Movable\")\n cv2.setMouseCallback(\"Ultralytics YOLO Region Counter Movable\", mouse_callback)\n cv2.imshow(\"Ultralytics YOLO Region Counter Movable\", frame)\n\n if save_img:\n video_writer.write(frame)\n\n for region in counting_regions: # Reinitialize count for each region\n region[\"counts\"] = 0\n\n if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n break\n\n del vid_frame_count\n video_writer.release()\n videocapture.release()\n cv2.destroyAllWindows()", "chunk_type": "function", "name": "run", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 86, "end_line": 226, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Run object detection and counting within specified regions using YOLO and ByteTrack.\n\nThis function performs real-time object detection, tracking, and counting within user-defined polygonal or\nrectangular regions. It supports interactive region manipulation, multiple counting areas, and both live viewing\nand video saving capabilities.\n\nArgs:\n weights (str): Path to the YOLO model weights file.\n source (str): Path to the input video file.\n device (str): Processing device specification ('cpu', '0', '1', etc.).\n view_img (bool): Display results in a live window.\n save_img (bool): Save processed video to file.\n exist_ok (bool): Overwrite existing output files without incrementing.\n classes (List[int], optional): Specific class IDs to detect and track.\n line_thickness (int): Thickness of bounding box lines.\n track_thickness (int): Thickness of object tracking lines.\n region_thickness (int): Thickness of counting region boundaries.\n\nExamples:\n Run region counting with default settings\n >>> run(source=\"video.mp4\", view_img=True)\n\n Run with custom model and specific classes\n >>> run(weights=\"yolo11s.pt\", source=\"traffic.mp4\", classes=[0, 2, 3], device=\"0\")", "parameters": [ "weights: str", "source: str", "device: str", "view_img: bool", "save_img: bool", "exist_ok: bool", "classes: List[int]", "line_thickness: int", "track_thickness: int", "region_thickness: int" ], "return_type": "None", "decorators": [], "complexity_score": 15, "dependencies": [ "argparse", "collections.defaultdict", "pathlib.Path", "typing.Any", "typing.List", "cv2", "numpy", "shapely.geometry.Polygon", "shapely.geometry.point.Point", "ultralytics.YOLO", "ultralytics.utils.files.increment_path", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_run_41bc630d" }, { "content": "def parse_opt() -> argparse.Namespace:\n \"\"\"Parse command line arguments for the region counting application.\"\"\"\n parser = argparse.ArgumentParser()\n parser.add_argument(\"--weights\", type=str, default=\"yolo11n.pt\", help=\"initial weights path\")\n parser.add_argument(\"--device\", default=\"\", help=\"cuda device, i.e. 
0 or 0,1,2,3 or cpu\")\n parser.add_argument(\"--source\", type=str, required=True, help=\"video file path\")\n parser.add_argument(\"--view-img\", action=\"store_true\", help=\"show results\")\n parser.add_argument(\"--save-img\", action=\"store_true\", help=\"save results\")\n parser.add_argument(\"--exist-ok\", action=\"store_true\", help=\"existing project/name ok, do not increment\")\n parser.add_argument(\"--classes\", nargs=\"+\", type=int, help=\"filter by class: --classes 0, or --classes 0 2 3\")\n parser.add_argument(\"--line-thickness\", type=int, default=2, help=\"bounding box thickness\")\n parser.add_argument(\"--track-thickness\", type=int, default=2, help=\"Tracking line thickness\")\n parser.add_argument(\"--region-thickness\", type=int, default=4, help=\"Region thickness\")\n\n return parser.parse_args()", "chunk_type": "function", "name": "parse_opt", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 229, "end_line": 243, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Parse command line arguments for the region counting application.", "parameters": [], "return_type": "argparse.Namespace", "decorators": [], "complexity_score": 1, "dependencies": [ "argparse", "collections.defaultdict", "pathlib.Path", "typing.Any", "typing.List", "cv2", "numpy", "shapely.geometry.Polygon", "shapely.geometry.point.Point", "ultralytics.YOLO", "ultralytics.utils.files.increment_path", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_parse_opt_4aff2b57" }, { "content": "def main(options: argparse.Namespace) -> None:\n \"\"\"Execute the main region counting functionality with the provided options.\"\"\"\n run(**vars(options))", "chunk_type": "function", "name": "main", "file_path": "ultralytics\\examples\\YOLOv8-Region-Counter\\yolov8_region_counter.py", "start_line": 246, "end_line": 248, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Execute the main region counting functionality with the provided options.", "parameters": [ "options: argparse.Namespace" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "argparse", "collections.defaultdict", "pathlib.Path", "typing.Any", "typing.List", "cv2", "numpy", "shapely.geometry.Polygon", "shapely.geometry.point.Point", "ultralytics.YOLO", "ultralytics.utils.files.increment_path", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors" ], "chunk_id": "function_main_843a60e6" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_ec0f32ac" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_0ea77e68" }, { "content": "from sahi import AutoDetectionModel", "chunk_type": "import", "name": "AutoDetectionModel", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", 
"start_line": 6, "end_line": 6, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AutoDetectionModel_0202a377" }, { "content": "from sahi.predict import get_sliced_prediction", "chunk_type": "import", "name": "get_sliced_prediction", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_sliced_prediction_68ec8a61" }, { "content": "from sahi.utils.ultralytics import download_model_weights", "chunk_type": "import", "name": "download_model_weights", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_download_model_weights_65bc9a0f" }, { "content": "from ultralytics.utils.files import increment_path", "chunk_type": "import", "name": "increment_path", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_increment_path_447c55b9" }, { "content": "class SAHIInference:\n \"\"\"\n Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.\n\n This class integrates SAHI (Slicing Aided Hyper Inference) with YOLO11 models to perform efficient object detection\n on large images by slicing them into smaller pieces, running inference on each slice, and then merging the results.\n\n Attributes:\n detection_model (AutoDetectionModel): The loaded YOLO11 model wrapped with SAHI functionality.\n\n Methods:\n load_model: Load a YOLO11 model with specified weights for object detection using SAHI.\n inference: Run object detection on a video using YOLO11 and SAHI.\n parse_opt: Parse command line arguments for the inference process.\n\n Examples:\n Initialize and run SAHI inference on a video\n >>> sahi_inference = SAHIInference()\n >>> sahi_inference.inference(weights=\"yolo11n.pt\", source=\"video.mp4\", view_img=True)\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the SAHIInference class for performing sliced inference using SAHI with YOLO11 models.\"\"\"\n self.detection_model = None\n\n def load_model(self, weights: str, device: str) -> None:\n \"\"\"\n Load a YOLO11 model with specified weights for object detection using SAHI.\n\n Args:\n weights (str): Path to the model weights file.\n device (str): CUDA device, i.e., '0' or '0,1,2,3' or 'cpu'.\n \"\"\"\n from ultralytics.utils.torch_utils import select_device\n\n yolo11_model_path = f\"models/{weights}\"\n download_model_weights(yolo11_model_path) # Download model if not present\n self.detection_model = AutoDetectionModel.from_pretrained(\n model_type=\"ultralytics\", model_path=yolo11_model_path, device=select_device(device)\n )\n\n def inference(\n self,\n weights: str = \"yolo11n.pt\",\n source: str = \"test.mp4\",\n view_img: bool = False,\n save_img: bool = False,\n 
exist_ok: bool = False,\n device: str = \"\",\n hide_conf: bool = False,\n slice_width: int = 512,\n slice_height: int = 512,\n ) -> None:\n \"\"\"\n Run object detection on a video using YOLO11 and SAHI.\n\n The function processes each frame of the video, applies sliced inference using SAHI,\n and optionally displays and/or saves the results with bounding boxes and labels.\n\n Args:\n weights (str): Model weights' path.\n source (str): Video file path.\n view_img (bool): Whether to display results in a window.\n save_img (bool): Whether to save results to a video file.\n exist_ok (bool): Whether to overwrite existing output files.\n device (str, optional): CUDA device, i.e., '0' or '0,1,2,3' or 'cpu'.\n hide_conf (bool, optional): Flag to show or hide confidences in the output.\n slice_width (int, optional): Slice width for inference.\n slice_height (int, optional): Slice height for inference.\n \"\"\"\n # Video setup\n cap = cv2.VideoCapture(source)\n assert cap.isOpened(), \"Error reading video file\"\n\n # Output setup\n save_dir = increment_path(\"runs/detect/predict\", exist_ok)\n save_dir.mkdir(parents=True, exist_ok=True)\n\n # Load model\n self.load_model(weights, device)\n idx = 0 # Index for image frame writing\n while cap.isOpened():\n success, frame = cap.read()\n if not success:\n break\n\n # Perform sliced prediction using SAHI\n results = get_sliced_prediction(\n frame[..., ::-1], # Convert BGR to RGB\n self.detection_model,\n slice_height=slice_height,\n slice_width=slice_width,\n )\n\n # Display results if requested\n if view_img:\n cv2.imshow(\"Ultralytics YOLO Inference\", frame)\n\n # Save results if requested\n if save_img:\n idx += 1\n results.export_visuals(export_dir=save_dir, file_name=f\"img_{idx}\", hide_conf=hide_conf)\n\n # Break loop if 'q' is pressed\n if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n break\n\n # Clean up resources\n cap.release()\n cv2.destroyAllWindows()\n\n @staticmethod\n def parse_opt() -> argparse.Namespace:\n \"\"\"\n Parse command line arguments for the inference process.\n\n Returns:\n (argparse.Namespace): Parsed command line arguments.\n \"\"\"\n parser = argparse.ArgumentParser()\n parser.add_argument(\"--weights\", type=str, default=\"yolo11n.pt\", help=\"initial weights path\")\n parser.add_argument(\"--source\", type=str, required=True, help=\"video file path\")\n parser.add_argument(\"--view-img\", action=\"store_true\", help=\"show results\")\n parser.add_argument(\"--save-img\", action=\"store_true\", help=\"save results\")\n parser.add_argument(\"--exist-ok\", action=\"store_true\", help=\"existing project/name ok, do not increment\")\n parser.add_argument(\"--device\", default=\"\", help=\"cuda device, i.e. 
0 or 0,1,2,3 or cpu\")\n parser.add_argument(\"--hide-conf\", default=False, action=\"store_true\", help=\"display or hide confidences\")\n parser.add_argument(\"--slice-width\", default=512, type=int, help=\"Slice width for inference\")\n parser.add_argument(\"--slice-height\", default=512, type=int, help=\"Slice height for inference\")\n return parser.parse_args()", "chunk_type": "class", "name": "SAHIInference", "file_path": "ultralytics\\examples\\YOLOv8-SAHI-Inference-Video\\yolov8_sahi.py", "start_line": 13, "end_line": 142, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.\n\nThis class integrates SAHI (Slicing Aided Hyper Inference) with YOLO11 models to perform efficient object detection\non large images by slicing them into smaller pieces, running inference on each slice, and then merging the results.\n\nAttributes:\n detection_model (AutoDetectionModel): The loaded YOLO11 model wrapped with SAHI functionality.\n\nMethods:\n load_model: Load a YOLO11 model with specified weights for object detection using SAHI.\n inference: Run object detection on a video using YOLO11 and SAHI.\n parse_opt: Parse command line arguments for the inference process.\n\nExamples:\n Initialize and run SAHI inference on a video\n >>> sahi_inference = SAHIInference()\n >>> sahi_inference.inference(weights=\"yolo11n.pt\", source=\"video.mp4\", view_img=True)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", "cv2", "sahi.AutoDetectionModel", "sahi.predict.get_sliced_prediction", "sahi.utils.ultralytics.download_model_weights", "ultralytics.utils.files.increment_path", "ultralytics.utils.torch_utils.select_device" ], "chunk_id": "class_SAHIInference_a9f9f63f" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_abd60e5a" }, { "content": "from typing import List, Tuple, Union", "chunk_type": "import", "name": "List, Tuple, Union", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Tuple, Union_bd286055" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_2f016c2f" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_b22d08cf" 
}, { "content": "import onnxruntime as ort", "chunk_type": "import", "name": "onnxruntime", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_onnxruntime_aa5bced0" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_f58c15f2" }, { "content": "import ultralytics.utils.ops as ops", "chunk_type": "import", "name": "ultralytics.utils.ops", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ultralytics.utils.ops_fd2ef511" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_b65731d4" }, { "content": "from ultralytics.utils import ASSETS, YAML", "chunk_type": "import", "name": "ASSETS, YAML", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, YAML_03353577" }, { "content": "from ultralytics.utils.checks import check_yaml", "chunk_type": "import", "name": "check_yaml", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_yaml_34fdf91f" }, { "content": "class YOLOv8Seg:\n \"\"\"\n YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.\n\n This class implements a YOLOv8 instance segmentation model using ONNX Runtime for inference. 
It handles\n preprocessing of input images, running inference with the ONNX model, and postprocessing the results to\n generate bounding boxes and segmentation masks.\n\n Attributes:\n session (ort.InferenceSession): ONNX Runtime inference session for model execution.\n imgsz (Tuple[int, int]): Input image size as (height, width) for the model.\n classes (dict): Dictionary mapping class indices to class names from the dataset.\n conf (float): Confidence threshold for filtering detections.\n iou (float): IoU threshold used by non-maximum suppression.\n\n Methods:\n letterbox: Resize and pad image while maintaining aspect ratio.\n preprocess: Preprocess the input image before feeding it into the model.\n postprocess: Post-process model predictions to extract meaningful results.\n process_mask: Process prototype masks with predicted mask coefficients to generate instance segmentation masks.\n\n Examples:\n >>> model = YOLOv8Seg(\"yolov8n-seg.onnx\", conf=0.25, iou=0.7)\n >>> img = cv2.imread(\"image.jpg\")\n >>> results = model(img)\n >>> cv2.imshow(\"Segmentation\", results[0].plot())\n \"\"\"\n\n def __init__(self, onnx_model: str, conf: float = 0.25, iou: float = 0.7, imgsz: Union[int, Tuple[int, int]] = 640):\n \"\"\"\n Initialize the instance segmentation model using an ONNX model.\n\n Args:\n onnx_model (str): Path to the ONNX model file.\n conf (float, optional): Confidence threshold for filtering detections.\n iou (float, optional): IoU threshold for non-maximum suppression.\n imgsz (int | Tuple[int, int], optional): Input image size of the model. Can be an integer for square\n input or a tuple for rectangular input.\n \"\"\"\n self.session = ort.InferenceSession(\n onnx_model,\n providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n if torch.cuda.is_available()\n else [\"CPUExecutionProvider\"],\n )\n\n self.imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz\n self.classes = YAML.load(check_yaml(\"coco8.yaml\"))[\"names\"]\n self.conf = conf\n self.iou = iou\n\n def __call__(self, img: np.ndarray) -> List[Results]:\n \"\"\"\n Run inference on the input image using the ONNX model.\n\n Args:\n img (np.ndarray): The original input image in BGR format.\n\n Returns:\n (List[Results]): Processed detection results after post-processing, containing bounding boxes and\n segmentation masks.\n \"\"\"\n prep_img = self.preprocess(img, self.imgsz)\n outs = self.session.run(None, {self.session.get_inputs()[0].name: prep_img})\n return self.postprocess(img, prep_img, outs)\n\n def letterbox(self, img: np.ndarray, new_shape: Tuple[int, int] = (640, 640)) -> np.ndarray:\n \"\"\"\n Resize and pad image while maintaining aspect ratio.\n\n Args:\n img (np.ndarray): Input image in BGR format.\n new_shape (Tuple[int, int], optional): Target shape as (height, width).\n\n Returns:\n (np.ndarray): Resized and padded image.\n \"\"\"\n shape = img.shape[:2] # current shape [height, width]\n\n # Scale ratio (new / old)\n r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])\n\n # Compute padding\n new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))\n dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding\n\n if shape[::-1] != new_unpad: # resize\n img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)\n top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))\n left, right = int(round(dw - 0.1)), int(round(dw + 0.1))\n img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))\n\n return 
img\n\n def preprocess(self, img: np.ndarray, new_shape: Tuple[int, int]) -> np.ndarray:\n \"\"\"\n Preprocess the input image before feeding it into the model.\n\n Args:\n img (np.ndarray): The input image in BGR format.\n new_shape (Tuple[int, int]): The target shape for resizing as (height, width).\n\n Returns:\n (np.ndarray): Preprocessed image ready for model inference, with shape (1, 3, height, width) and\n normalized to [0, 1].\n \"\"\"\n img = self.letterbox(img, new_shape)\n img = img[..., ::-1].transpose([2, 0, 1])[None] # BGR to RGB, BHWC to BCHW\n img = np.ascontiguousarray(img)\n img = img.astype(np.float32) / 255 # Normalize to [0, 1]\n return img\n\n def postprocess(self, img: np.ndarray, prep_img: np.ndarray, outs: List) -> List[Results]:\n \"\"\"\n Post-process model predictions to extract meaningful results.\n\n Args:\n img (np.ndarray): The original input image.\n prep_img (np.ndarray): The preprocessed image used for inference.\n outs (List): Model outputs containing predictions and prototype masks.\n\n Returns:\n (List[Results]): Processed detection results containing bounding boxes and segmentation masks.\n \"\"\"\n preds, protos = [torch.from_numpy(p) for p in outs]\n preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))\n\n results = []\n for i, pred in enumerate(preds):\n pred[:, :4] = ops.scale_boxes(prep_img.shape[2:], pred[:, :4], img.shape)\n masks = self.process_mask(protos[i], pred[:, 6:], pred[:, :4], img.shape[:2])\n results.append(Results(img, path=\"\", names=self.classes, boxes=pred[:, :6], masks=masks))\n\n return results\n\n def process_mask(\n self, protos: torch.Tensor, masks_in: torch.Tensor, bboxes: torch.Tensor, shape: Tuple[int, int]\n ) -> torch.Tensor:\n \"\"\"\n Process prototype masks with predicted mask coefficients to generate instance segmentation masks.\n\n Args:\n protos (torch.Tensor): Prototype masks with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Predicted mask coefficients with shape (N, mask_dim), where N is number of\n detections.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4), where N is number of detections.\n shape (Tuple[int, int]): The size of the input image as (height, width).\n\n Returns:\n (torch.Tensor): Binary segmentation masks with shape (N, height, width).\n \"\"\"\n c, mh, mw = protos.shape # CHW\n masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # Matrix multiplication\n masks = ops.scale_masks(masks[None], shape)[0] # Scale masks to original image size\n masks = ops.crop_mask(masks, bboxes) # Crop masks to bounding boxes\n return masks.gt_(0.0) # Convert to binary masks", "chunk_type": "class", "name": "YOLOv8Seg", "file_path": "ultralytics\\examples\\YOLOv8-Segmentation-ONNXRuntime-Python\\main.py", "start_line": 17, "end_line": 172, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": "YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.\n\nThis class implements a YOLOv8 instance segmentation model using ONNX Runtime for inference. 
It handles\npreprocessing of input images, running inference with the ONNX model, and postprocessing the results to\ngenerate bounding boxes and segmentation masks.\n\nAttributes:\n session (ort.InferenceSession): ONNX Runtime inference session for model execution.\n imgsz (Tuple[int, int]): Input image size as (height, width) for the model.\n classes (dict): Dictionary mapping class indices to class names from the dataset.\n conf (float): Confidence threshold for filtering detections.\n iou (float): IoU threshold used by non-maximum suppression.\n\nMethods:\n letterbox: Resize and pad image while maintaining aspect ratio.\n preprocess: Preprocess the input image before feeding it into the model.\n postprocess: Post-process model predictions to extract meaningful results.\n process_mask: Process prototype masks with predicted mask coefficients to generate instance segmentation masks.\n\nExamples:\n >>> model = YOLOv8Seg(\"yolov8n-seg.onnx\", conf=0.25, iou=0.7)\n >>> img = cv2.imread(\"image.jpg\")\n >>> results = model(img)\n >>> cv2.imshow(\"Segmentation\", results[0].plot())", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "onnxruntime", "torch", "ultralytics.utils.ops", "ultralytics.engine.results.Results", "ultralytics.utils.ASSETS", "ultralytics.utils.YAML", "ultralytics.utils.checks.check_yaml" ], "chunk_id": "class_YOLOv8Seg_7e97be45" }, { "content": "import argparse", "chunk_type": "import", "name": "argparse", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_argparse_ce04c982" }, { "content": "from typing import Tuple, Union", "chunk_type": "import", "name": "Tuple, Union", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Tuple, Union_45e1d112" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_fa86d4a9" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_0711ccb0" }, { "content": "import yaml", "chunk_type": "import", "name": "yaml", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yaml_46c9e601" }, { "content": "from ultralytics.utils import ASSETS", "chunk_type": 
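A usage sketch for the YOLOv8Seg wrapper catalogued above; the export call is one assumed way to produce "yolov8n-seg.onnx", "image.jpg" is a placeholder input, and the class itself is assumed importable from the catalogued YOLOv8-Segmentation-ONNXRuntime-Python/main.py.

import cv2
from ultralytics import YOLO

# One-time export of a segmentation checkpoint to the ONNX file the wrapper loads.
YOLO("yolov8n-seg.pt").export(format="onnx", imgsz=640)

model = YOLOv8Seg("yolov8n-seg.onnx", conf=0.25, iou=0.7)
img = cv2.imread("image.jpg")           # placeholder input image (BGR)
results = model(img)                    # list of ultralytics Results with boxes and masks
annotated = results[0].plot()           # draw boxes and masks onto a copy of the image
cv2.imwrite("segmented.jpg", annotated)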
"import", "name": "ASSETS", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS_a81c9214" }, { "content": "class YOLOv8TFLite:\n \"\"\"\n A YOLOv8 object detection class using TensorFlow Lite for efficient inference.\n\n This class handles model loading, preprocessing, inference, and visualization of detection results for YOLOv8\n models converted to TensorFlow Lite format.\n\n Attributes:\n model (Interpreter): TensorFlow Lite interpreter for the YOLOv8 model.\n conf (float): Confidence threshold for filtering detections.\n iou (float): Intersection over Union threshold for non-maximum suppression.\n classes (dict): Dictionary mapping class IDs to class names.\n color_palette (np.ndarray): Random color palette for visualization with shape (num_classes, 3).\n in_width (int): Input width required by the model.\n in_height (int): Input height required by the model.\n in_index (int): Input tensor index in the model.\n in_scale (float): Input quantization scale factor.\n in_zero_point (int): Input quantization zero point.\n int8 (bool): Whether the model uses int8 quantization.\n out_index (int): Output tensor index in the model.\n out_scale (float): Output quantization scale factor.\n out_zero_point (int): Output quantization zero point.\n\n Methods:\n letterbox: Resize and pad image while maintaining aspect ratio.\n draw_detections: Draw bounding boxes and labels on the input image.\n preprocess: Preprocess the input image before inference.\n postprocess: Process model outputs to extract and visualize detections.\n detect: Perform object detection on an input image.\n\n Examples:\n Initialize detector and run inference\n >>> detector = YOLOv8TFLite(\"yolov8n.tflite\", conf=0.25, iou=0.45)\n >>> result = detector.detect(\"image.jpg\")\n >>> cv2.imshow(\"Result\", result)\n \"\"\"\n\n def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: Union[str, None] = None):\n \"\"\"\n Initialize the YOLOv8TFLite detector.\n\n Args:\n model (str): Path to the TFLite model file.\n conf (float): Confidence threshold for filtering detections.\n iou (float): IoU threshold for non-maximum suppression.\n metadata (str | None): Path to the metadata file containing class names.\n \"\"\"\n self.conf = conf\n self.iou = iou\n if metadata is None:\n self.classes = {i: i for i in range(1000)}\n else:\n with open(metadata) as f:\n self.classes = yaml.safe_load(f)[\"names\"]\n np.random.seed(42) # Set seed for reproducible colors\n self.color_palette = np.random.uniform(128, 255, size=(len(self.classes), 3))\n\n # Initialize the TFLite interpreter\n self.model = Interpreter(model_path=model)\n self.model.allocate_tensors()\n\n # Get input details\n input_details = self.model.get_input_details()[0]\n self.in_width, self.in_height = input_details[\"shape\"][1:3]\n self.in_index = input_details[\"index\"]\n self.in_scale, self.in_zero_point = input_details[\"quantization\"]\n self.int8 = input_details[\"dtype\"] == np.int8\n\n # Get output details\n output_details = self.model.get_output_details()[0]\n self.out_index = output_details[\"index\"]\n self.out_scale, self.out_zero_point = output_details[\"quantization\"]\n\n def letterbox(\n self, img: np.ndarray, new_shape: Tuple[int, int] = (640, 640)\n ) -> Tuple[np.ndarray, Tuple[float, 
float]]:\n \"\"\"\n Resize and pad image while maintaining aspect ratio.\n\n Args:\n img (np.ndarray): Input image with shape (H, W, C).\n new_shape (Tuple[int, int]): Target shape (height, width).\n\n Returns:\n (np.ndarray): Resized and padded image.\n (Tuple[float, float]): Padding ratios (top/height, left/width) for coordinate adjustment.\n \"\"\"\n shape = img.shape[:2] # Current shape [height, width]\n\n # Scale ratio (new / old)\n r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])\n\n # Compute padding\n new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))\n dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding\n\n if shape[::-1] != new_unpad: # Resize if needed\n img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)\n top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))\n left, right = int(round(dw - 0.1)), int(round(dw + 0.1))\n img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))\n\n return img, (top / img.shape[0], left / img.shape[1])\n\n def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None:\n \"\"\"\n Draw bounding boxes and labels on the input image based on detected objects.\n\n Args:\n img (np.ndarray): The input image to draw detections on.\n box (np.ndarray): Detected bounding box in the format [x1, y1, width, height].\n score (np.float32): Confidence score of the detection.\n class_id (int): Class ID for the detected object.\n \"\"\"\n x1, y1, w, h = box\n color = self.color_palette[class_id]\n\n # Draw bounding box\n cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)\n\n # Create label with class name and score\n label = f\"{self.classes[class_id]}: {score:.2f}\"\n\n # Get text size for background rectangle\n (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)\n\n # Position label above or below box depending on space\n label_x = x1\n label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10\n\n # Draw label background\n cv2.rectangle(\n img,\n (int(label_x), int(label_y - label_height)),\n (int(label_x + label_width), int(label_y + label_height)),\n color,\n cv2.FILLED,\n )\n\n # Draw text\n cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)\n\n def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float]]:\n \"\"\"\n Preprocess the input image before performing inference.\n\n Args:\n img (np.ndarray): The input image to be preprocessed with shape (H, W, C).\n\n Returns:\n (np.ndarray): Preprocessed image ready for model input.\n (Tuple[float, float]): Padding ratios for coordinate adjustment.\n \"\"\"\n img, pad = self.letterbox(img, (self.in_width, self.in_height))\n img = img[..., ::-1][None] # BGR to RGB and add batch dimension (N, H, W, C) for TFLite\n img = np.ascontiguousarray(img)\n img = img.astype(np.float32)\n return img / 255, pad # Normalize to [0, 1]\n\n def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: Tuple[float, float]) -> np.ndarray:\n \"\"\"\n Process model outputs to extract and visualize detections.\n\n Args:\n img (np.ndarray): The original input image.\n outputs (np.ndarray): Raw model outputs.\n pad (Tuple[float, float]): Padding ratios from preprocessing.\n\n Returns:\n (np.ndarray): The input image with detections drawn on it.\n \"\"\"\n # Adjust coordinates based on padding and scale to original image size\n outputs[:, 0] -= 
pad[1]\n outputs[:, 1] -= pad[0]\n outputs[:, :4] *= max(img.shape)\n\n # Transform outputs to [x, y, w, h] format\n outputs = outputs.transpose(0, 2, 1)\n outputs[..., 0] -= outputs[..., 2] / 2 # x center to top-left x\n outputs[..., 1] -= outputs[..., 3] / 2 # y center to top-left y\n\n for out in outputs:\n # Get scores and apply confidence threshold\n scores = out[:, 4:].max(-1)\n keep = scores > self.conf\n boxes = out[keep, :4]\n scores = scores[keep]\n class_ids = out[keep, 4:].argmax(-1)\n\n # Apply non-maximum suppression\n indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf, self.iou).flatten()\n\n # Draw detections that survived NMS\n [self.draw_detections(img, boxes[i], scores[i], class_ids[i]) for i in indices]\n\n return img\n\n def detect(self, img_path: str) -> np.ndarray:\n \"\"\"\n Perform object detection on an input image.\n\n Args:\n img_path (str): Path to the input image file.\n\n Returns:\n (np.ndarray): The output image with drawn detections.\n \"\"\"\n # Load and preprocess image\n img = cv2.imread(img_path)\n x, pad = self.preprocess(img)\n\n # Apply quantization if model is int8\n if self.int8:\n x = (x / self.in_scale + self.in_zero_point).astype(np.int8)\n\n # Set input tensor and run inference\n self.model.set_tensor(self.in_index, x)\n self.model.invoke()\n\n # Get output and dequantize if necessary\n y = self.model.get_tensor(self.out_index)\n if self.int8:\n y = (y.astype(np.float32) - self.out_zero_point) * self.out_scale\n\n # Process detections and return result\n return self.postprocess(img, y, pad)", "chunk_type": "class", "name": "YOLOv8TFLite", "file_path": "ultralytics\\examples\\YOLOv8-TFLite-Python\\main.py", "start_line": 20, "end_line": 245, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "A YOLOv8 object detection class using TensorFlow Lite for efficient inference.\n\nThis class handles model loading, preprocessing, inference, and visualization of detection results for YOLOv8\nmodels converted to TensorFlow Lite format.\n\nAttributes:\n model (Interpreter): TensorFlow Lite interpreter for the YOLOv8 model.\n conf (float): Confidence threshold for filtering detections.\n iou (float): Intersection over Union threshold for non-maximum suppression.\n classes (dict): Dictionary mapping class IDs to class names.\n color_palette (np.ndarray): Random color palette for visualization with shape (num_classes, 3).\n in_width (int): Input width required by the model.\n in_height (int): Input height required by the model.\n in_index (int): Input tensor index in the model.\n in_scale (float): Input quantization scale factor.\n in_zero_point (int): Input quantization zero point.\n int8 (bool): Whether the model uses int8 quantization.\n out_index (int): Output tensor index in the model.\n out_scale (float): Output quantization scale factor.\n out_zero_point (int): Output quantization zero point.\n\nMethods:\n letterbox: Resize and pad image while maintaining aspect ratio.\n draw_detections: Draw bounding boxes and labels on the input image.\n preprocess: Preprocess the input image before inference.\n postprocess: Process model outputs to extract and visualize detections.\n detect: Perform object detection on an input image.\n\nExamples:\n Initialize detector and run inference\n >>> detector = YOLOv8TFLite(\"yolov8n.tflite\", conf=0.25, iou=0.45)\n >>> result = detector.detect(\"image.jpg\")\n >>> cv2.imshow(\"Result\", result)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "argparse", 
"typing.Tuple", "typing.Union", "cv2", "numpy", "yaml", "ultralytics.utils.ASSETS", "tflite_runtime.interpreter.Interpreter", "tensorflow" ], "chunk_id": "class_YOLOv8TFLite_0ecd01e0" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_301995cf" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_167ef506" }, { "content": "import sys", "chunk_type": "import", "name": "sys", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_sys_aa9f3ede" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_181adb73" }, { "content": "from types import SimpleNamespace", "chunk_type": "import", "name": "SimpleNamespace", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SimpleNamespace_37b14440" }, { "content": "from typing import Any, Dict, List, Union", "chunk_type": "import", "name": "Any, Dict, List, Union", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Union_c06c07bb" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___e50c687b" }, { "content": "from ultralytics.utils import (\n ASSETS,\n DEFAULT_CFG,\n DEFAULT_CFG_DICT,\n DEFAULT_CFG_PATH,\n IS_VSCODE,\n LOGGER,\n RANK,\n ROOT,\n RUNS_DIR,\n SETTINGS,\n SETTINGS_FILE,\n TESTS_RUNNING,\n YAML,\n IterableSimpleNamespace,\n checks,\n colorstr,\n deprecation_warn,\n vscode_msg,\n)", "chunk_type": "import", "name": "ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, IS_VSCODE, LOGGER, RANK, ROOT, RUNS_DIR, SETTINGS, SETTINGS_FILE, TESTS_RUNNING, YAML, IterableSimpleNamespace, checks, colorstr, deprecation_warn, 
vscode_msg", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 11, "end_line": 30, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, IS_VSCODE, LOGGER, RANK, ROOT, RUNS_DIR, SETTINGS, SETTINGS_FILE, TESTS_RUNNING, YAML, IterableSimpleNamespace, checks, colorstr, deprecation_warn, vscode_msg_a842cdb4" }, { "content": "SOLUTION_MAP = {\n \"count\": \"ObjectCounter\",\n \"crop\": \"ObjectCropper\",\n \"blur\": \"ObjectBlurrer\",\n \"workout\": \"AIGym\",\n \"heatmap\": \"Heatmap\",\n \"isegment\": \"InstanceSegmentation\",\n \"visioneye\": \"VisionEye\",\n \"speed\": \"SpeedEstimator\",\n \"queue\": \"QueueManager\",\n \"analytics\": \"Analytics\",\n \"inference\": \"Inference\",\n \"trackzone\": \"TrackZone\",\n \"help\": None,\n}", "chunk_type": "variable", "name": "SOLUTION_MAP", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 33, "end_line": 47, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SOLUTION_MAP_c64003d7" }, { "content": "MODES = frozenset({\"train\", \"val\", \"predict\", \"export\", \"track\", \"benchmark\"})", "chunk_type": "variable", "name": "MODES", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 50, "end_line": 50, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_MODES_64d58f60" }, { "content": "TASKS = frozenset({\"detect\", \"segment\", \"classify\", \"pose\", \"obb\"})", "chunk_type": "variable", "name": "TASKS", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 51, "end_line": 51, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TASKS_84578197" }, { "content": "TASK2DATA = {\n \"detect\": \"coco8.yaml\",\n \"segment\": \"coco8-seg.yaml\",\n \"classify\": \"imagenet10\",\n \"pose\": \"coco8-pose.yaml\",\n \"obb\": \"dota8.yaml\",\n}", "chunk_type": "variable", "name": "TASK2DATA", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 52, "end_line": 58, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TASK2DATA_4857dec3" }, { "content": "TASK2MODEL = {\n \"detect\": \"yolo11n.pt\",\n \"segment\": \"yolo11n-seg.pt\",\n \"classify\": \"yolo11n-cls.pt\",\n \"pose\": \"yolo11n-pose.pt\",\n \"obb\": \"yolo11n-obb.pt\",\n}", "chunk_type": "variable", "name": "TASK2MODEL", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 59, "end_line": 65, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TASK2MODEL_3c5da394" }, { "content": "TASK2METRIC = {\n \"detect\": \"metrics/mAP50-95(B)\",\n \"segment\": \"metrics/mAP50-95(M)\",\n \"classify\": \"metrics/accuracy_top1\",\n \"pose\": 
\"metrics/mAP50-95(P)\",\n \"obb\": \"metrics/mAP50-95(B)\",\n}", "chunk_type": "variable", "name": "TASK2METRIC", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 66, "end_line": 72, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TASK2METRIC_ccf12afb" }, { "content": "ARGV = sys.argv or [\"\", \"\"] # sometimes sys.argv = []", "chunk_type": "variable", "name": "ARGV", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 74, "end_line": 74, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ARGV_d48bb83f" }, { "content": "SOLUTIONS_HELP_MSG = f\"\"\"\n Arguments received: {str([\"yolo\"] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview:\n\n yolo solutions SOLUTION ARGS\n\n Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())[:-1]}\n ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults \n at https://docs.ultralytics.com/usage/cfg\n \n 1. Call object counting solution\n yolo solutions count source=\"path/to/video.mp4\" region=\"[(20, 400), (1080, 400), (1080, 360), (20, 360)]\"\n\n 2. Call heatmaps solution\n yolo solutions heatmap colormap=cv2.COLORMAP_PARULA model=yolo11n.pt\n\n 3. Call queue management solution\n yolo solutions queue region=\"[(20, 400), (1080, 400), (1080, 360), (20, 360)]\" model=yolo11n.pt\n\n 4. Call workouts monitoring solution for push-ups\n yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10]\n\n 5. Generate analytical graphs\n yolo solutions analytics analytics_type=\"pie\"\n \n 6. Track objects within specific zones\n yolo solutions trackzone source=\"path/to/video.mp4\" region=\"[(150, 150), (1130, 150), (1130, 570), (150, 570)]\"\n \n 7. Streamlit real-time webcam inference GUI\n yolo streamlit-predict\n \"\"\"", "chunk_type": "variable", "name": "SOLUTIONS_HELP_MSG", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 75, "end_line": 104, "start_col": 0, "end_col": 7, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SOLUTIONS_HELP_MSG_916b8677" }, { "content": "CLI_HELP_MSG = f\"\"\"\n Arguments received: {str([\"yolo\"] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:\n\n yolo TASK MODE ARGS\n\n Where TASK (optional) is one of {list(TASKS)}\n MODE (required) is one of {list(MODES)}\n ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.\n See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'\n\n 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01\n yolo train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01\n\n 2. Predict a YouTube video using a pretrained segmentation model at image size 320:\n yolo predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320\n\n 3. Val a pretrained detection model at batch-size 1 and image size 640:\n yolo val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640\n\n 4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)\n yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128\n\n 5. 
Ultralytics solutions usage\n yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source=\"path/to/video.mp4\"\n\n 6. Run special commands:\n yolo help\n yolo checks\n yolo version\n yolo settings\n yolo copy-cfg\n yolo cfg\n yolo solutions help\n\n Docs: https://docs.ultralytics.com\n Solutions: https://docs.ultralytics.com/solutions/\n Community: https://community.ultralytics.com\n GitHub: https://github.com/ultralytics/ultralytics\n \"\"\"", "chunk_type": "variable", "name": "CLI_HELP_MSG", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 105, "end_line": 143, "start_col": 0, "end_col": 7, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CLI_HELP_MSG_f1f34b46" }, { "content": "CFG_FLOAT_KEYS = frozenset(\n { # integer or float arguments, i.e. x=2 and x=2.0\n \"warmup_epochs\",\n \"box\",\n \"cls\",\n \"dfl\",\n \"degrees\",\n \"shear\",\n \"time\",\n \"workspace\",\n \"batch\",\n }\n)", "chunk_type": "variable", "name": "CFG_FLOAT_KEYS", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 146, "end_line": 158, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CFG_FLOAT_KEYS_3738c9b0" }, { "content": "CFG_FRACTION_KEYS = frozenset(\n { # fractional float arguments with 0.0<=values<=1.0\n \"dropout\",\n \"lr0\",\n \"lrf\",\n \"momentum\",\n \"weight_decay\",\n \"warmup_momentum\",\n \"warmup_bias_lr\",\n \"hsv_h\",\n \"hsv_s\",\n \"hsv_v\",\n \"translate\",\n \"scale\",\n \"perspective\",\n \"flipud\",\n \"fliplr\",\n \"bgr\",\n \"mosaic\",\n \"mixup\",\n \"cutmix\",\n \"copy_paste\",\n \"conf\",\n \"iou\",\n \"fraction\",\n }\n)", "chunk_type": "variable", "name": "CFG_FRACTION_KEYS", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 159, "end_line": 185, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CFG_FRACTION_KEYS_f7ace963" }, { "content": "CFG_INT_KEYS = frozenset(\n { # integer-only arguments\n \"epochs\",\n \"patience\",\n \"workers\",\n \"seed\",\n \"close_mosaic\",\n \"mask_ratio\",\n \"max_det\",\n \"vid_stride\",\n \"line_width\",\n \"nbs\",\n \"save_period\",\n }\n)", "chunk_type": "variable", "name": "CFG_INT_KEYS", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 186, "end_line": 200, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CFG_INT_KEYS_424546cf" }, { "content": "CFG_BOOL_KEYS = frozenset(\n { # boolean-only arguments\n \"save\",\n \"exist_ok\",\n \"verbose\",\n \"deterministic\",\n \"single_cls\",\n \"rect\",\n \"cos_lr\",\n \"overlap_mask\",\n \"val\",\n \"save_json\",\n \"half\",\n \"dnn\",\n \"plots\",\n \"show\",\n \"save_txt\",\n \"save_conf\",\n \"save_crop\",\n \"save_frames\",\n \"show_labels\",\n \"show_conf\",\n \"visualize\",\n \"augment\",\n \"agnostic_nms\",\n \"retina_masks\",\n \"show_boxes\",\n \"keras\",\n \"optimize\",\n \"int8\",\n \"dynamic\",\n \"simplify\",\n \"nms\",\n \"profile\",\n \"multi_scale\",\n }\n)", "chunk_type": "variable", "name": "CFG_BOOL_KEYS", "file_path": 
"ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 201, "end_line": 237, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_CFG_BOOL_KEYS_6fd1325b" }, { "content": "def cfg2dict(cfg: Union[str, Path, Dict, SimpleNamespace]) -> Dict:\n \"\"\"\n Convert a configuration object to a dictionary.\n\n Args:\n cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path,\n a string, a dictionary, or a SimpleNamespace object.\n\n Returns:\n (dict): Configuration object in dictionary format.\n\n Examples:\n Convert a YAML file path to a dictionary:\n >>> config_dict = cfg2dict(\"config.yaml\")\n\n Convert a SimpleNamespace to a dictionary:\n >>> from types import SimpleNamespace\n >>> config_sn = SimpleNamespace(param1=\"value1\", param2=\"value2\")\n >>> config_dict = cfg2dict(config_sn)\n\n Pass through an already existing dictionary:\n >>> config_dict = cfg2dict({\"param1\": \"value1\", \"param2\": \"value2\"})\n\n Notes:\n - If cfg is a path or string, it's loaded as YAML and converted to a dictionary.\n - If cfg is a SimpleNamespace object, it's converted to a dictionary using vars().\n - If cfg is already a dictionary, it's returned unchanged.\n \"\"\"\n if isinstance(cfg, (str, Path)):\n cfg = YAML.load(cfg) # load dict\n elif isinstance(cfg, SimpleNamespace):\n cfg = vars(cfg) # convert to dict\n return cfg", "chunk_type": "function", "name": "cfg2dict", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 240, "end_line": 272, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Convert a configuration object to a dictionary.\n\nArgs:\n cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. 
Can be a file path,\n a string, a dictionary, or a SimpleNamespace object.\n\nReturns:\n (dict): Configuration object in dictionary format.\n\nExamples:\n Convert a YAML file path to a dictionary:\n >>> config_dict = cfg2dict(\"config.yaml\")\n\n Convert a SimpleNamespace to a dictionary:\n >>> from types import SimpleNamespace\n >>> config_sn = SimpleNamespace(param1=\"value1\", param2=\"value2\")\n >>> config_dict = cfg2dict(config_sn)\n\n Pass through an already existing dictionary:\n >>> config_dict = cfg2dict({\"param1\": \"value1\", \"param2\": \"value2\"})\n\nNotes:\n - If cfg is a path or string, it's loaded as YAML and converted to a dictionary.\n - If cfg is a SimpleNamespace object, it's converted to a dictionary using vars().\n - If cfg is already a dictionary, it's returned unchanged.", "parameters": [ "cfg: Union[str, Path, Dict, SimpleNamespace]" ], "return_type": "Dict", "decorators": [], "complexity_score": 3, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_cfg2dict_3dd1f37a" }, { "content": "def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None) -> SimpleNamespace:\n \"\"\"\n Load and merge configuration data from a file or dictionary, with optional overrides.\n\n Args:\n cfg (str | Path | Dict | SimpleNamespace): Configuration data source. 
Can be a file path, dictionary, or\n SimpleNamespace object.\n overrides (Dict | None): Dictionary containing key-value pairs to override the base configuration.\n\n Returns:\n (SimpleNamespace): Namespace containing the merged configuration arguments.\n\n Examples:\n >>> from ultralytics.cfg import get_cfg\n >>> config = get_cfg() # Load default configuration\n >>> config_with_overrides = get_cfg(\"path/to/config.yaml\", overrides={\"epochs\": 50, \"batch_size\": 16})\n\n Notes:\n - If both `cfg` and `overrides` are provided, the values in `overrides` will take precedence.\n - Special handling ensures alignment and correctness of the configuration, such as converting numeric\n `project` and `name` to strings and validating configuration keys and values.\n - The function performs type and value checks on the configuration data.\n \"\"\"\n cfg = cfg2dict(cfg)\n\n # Merge overrides\n if overrides:\n overrides = cfg2dict(overrides)\n if \"save_dir\" not in cfg:\n overrides.pop(\"save_dir\", None) # special override keys to ignore\n check_dict_alignment(cfg, overrides)\n cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)\n\n # Special handling for numeric project/name\n for k in \"project\", \"name\":\n if k in cfg and isinstance(cfg[k], (int, float)):\n cfg[k] = str(cfg[k])\n if cfg.get(\"name\") == \"model\": # assign model to 'name' arg\n cfg[\"name\"] = str(cfg.get(\"model\", \"\")).partition(\".\")[0]\n LOGGER.warning(f\"'name=model' automatically updated to 'name={cfg['name']}'.\")\n\n # Type and Value checks\n check_cfg(cfg)\n\n # Return instance\n return IterableSimpleNamespace(**cfg)", "chunk_type": "function", "name": "get_cfg", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 275, "end_line": 320, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Load and merge configuration data from a file or dictionary, with optional overrides.\n\nArgs:\n cfg (str | Path | Dict | SimpleNamespace): Configuration data source. 
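get_cfg above merges a base configuration dictionary with user overrides, preferring the overrides, and stringifies numeric project/name values before wrapping the result in a namespace. A minimal standalone sketch of that merge behaviour is shown below; the base dictionary is a small made-up example rather than DEFAULT_CFG_DICT, and merge_cfg is an illustrative helper, not the Ultralytics function.

```python
from types import SimpleNamespace

def merge_cfg(base: dict, overrides: dict = None) -> SimpleNamespace:
    """Merge overrides into base (overrides win) and normalize numeric project/name to str."""
    cfg = {**base, **(overrides or {})}
    for k in ("project", "name"):
        if isinstance(cfg.get(k), (int, float)):
            cfg[k] = str(cfg[k])  # mirrors the special handling in get_cfg
    return SimpleNamespace(**cfg)

base = {"epochs": 100, "imgsz": 640, "project": None, "name": None}  # illustrative defaults only
cfg = merge_cfg(base, {"epochs": 10, "project": 2024})
print(cfg.epochs, cfg.project)  # 10 '2024'
```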
Can be a file path, dictionary, or\n SimpleNamespace object.\n overrides (Dict | None): Dictionary containing key-value pairs to override the base configuration.\n\nReturns:\n (SimpleNamespace): Namespace containing the merged configuration arguments.\n\nExamples:\n >>> from ultralytics.cfg import get_cfg\n >>> config = get_cfg() # Load default configuration\n >>> config_with_overrides = get_cfg(\"path/to/config.yaml\", overrides={\"epochs\": 50, \"batch_size\": 16})\n\nNotes:\n - If both `cfg` and `overrides` are provided, the values in `overrides` will take precedence.\n - Special handling ensures alignment and correctness of the configuration, such as converting numeric\n `project` and `name` to strings and validating configuration keys and values.\n - The function performs type and value checks on the configuration data.", "parameters": [ "cfg: Union[str, Path, Dict, SimpleNamespace]", "overrides: Dict" ], "return_type": "SimpleNamespace", "decorators": [], "complexity_score": 6, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_get_cfg_0a7f7b89" }, { "content": "def check_cfg(cfg: Dict, hard: bool = True) -> None:\n \"\"\"\n Check configuration argument types and values for the Ultralytics library.\n\n This function validates the types and values of configuration arguments, ensuring correctness and converting\n them if necessary. It checks for specific key types defined in global variables such as `CFG_FLOAT_KEYS`,\n `CFG_FRACTION_KEYS`, `CFG_INT_KEYS`, and `CFG_BOOL_KEYS`.\n\n Args:\n cfg (dict): Configuration dictionary to validate.\n hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them.\n\n Examples:\n >>> config = {\n ... \"epochs\": 50, # valid integer\n ... \"lr0\": 0.01, # valid float\n ... \"momentum\": 1.2, # invalid float (out of 0.0-1.0 range)\n ... \"save\": \"true\", # invalid bool\n ... }\n >>> check_cfg(config, hard=False)\n >>> print(config)\n {'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key\n\n Notes:\n - The function modifies the input dictionary in-place.\n - None values are ignored as they may be from optional arguments.\n - Fraction keys are checked to be within the range [0.0, 1.0].\n \"\"\"\n for k, v in cfg.items():\n if v is not None: # None values may be from optional args\n if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):\n if hard:\n raise TypeError(\n f\"'{k}={v}' is of invalid type {type(v).__name__}. \"\n f\"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. 
'{k}=0.5')\"\n )\n cfg[k] = float(v)\n elif k in CFG_FRACTION_KEYS:\n if not isinstance(v, (int, float)):\n if hard:\n raise TypeError(\n f\"'{k}={v}' is of invalid type {type(v).__name__}. \"\n f\"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')\"\n )\n cfg[k] = v = float(v)\n if not (0.0 <= v <= 1.0):\n raise ValueError(f\"'{k}={v}' is an invalid value. Valid '{k}' values are between 0.0 and 1.0.\")\n elif k in CFG_INT_KEYS and not isinstance(v, int):\n if hard:\n raise TypeError(\n f\"'{k}={v}' is of invalid type {type(v).__name__}. '{k}' must be an int (i.e. '{k}=8')\"\n )\n cfg[k] = int(v)\n elif k in CFG_BOOL_KEYS and not isinstance(v, bool):\n if hard:\n raise TypeError(\n f\"'{k}={v}' is of invalid type {type(v).__name__}. \"\n f\"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')\"\n )\n cfg[k] = bool(v)", "chunk_type": "function", "name": "check_cfg", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 323, "end_line": 382, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Check configuration argument types and values for the Ultralytics library.\n\nThis function validates the types and values of configuration arguments, ensuring correctness and converting\nthem if necessary. It checks for specific key types defined in global variables such as `CFG_FLOAT_KEYS`,\n`CFG_FRACTION_KEYS`, `CFG_INT_KEYS`, and `CFG_BOOL_KEYS`.\n\nArgs:\n cfg (dict): Configuration dictionary to validate.\n hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them.\n\nExamples:\n >>> config = {\n ... \"epochs\": 50, # valid integer\n ... \"lr0\": 0.01, # valid float\n ... \"momentum\": 1.2, # invalid float (out of 0.0-1.0 range)\n ... \"save\": \"true\", # invalid bool\n ... 
}\n >>> check_cfg(config, hard=False)\n >>> print(config)\n {'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key\n\nNotes:\n - The function modifies the input dictionary in-place.\n - None values are ignored as they may be from optional arguments.\n - Fraction keys are checked to be within the range [0.0, 1.0].", "parameters": [ "cfg: Dict", "hard: bool" ], "return_type": "None", "decorators": [], "complexity_score": 13, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_check_cfg_93091c0e" }, { "content": "def get_save_dir(args: SimpleNamespace, name: str = None) -> Path:\n \"\"\"\n Return the directory path for saving outputs, derived from arguments or default settings.\n\n Args:\n args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task',\n 'mode', and 'save_dir'.\n name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name'\n or the 'args.mode'.\n\n Returns:\n (Path): Directory path where outputs should be saved.\n\n Examples:\n >>> from types import SimpleNamespace\n >>> args = SimpleNamespace(project=\"my_project\", task=\"detect\", mode=\"train\", exist_ok=True)\n >>> save_dir = get_save_dir(args)\n >>> print(save_dir)\n my_project/detect/train\n \"\"\"\n if getattr(args, \"save_dir\", None):\n save_dir = args.save_dir\n else:\n from ultralytics.utils.files import increment_path\n\n project = args.project or (ROOT.parent / \"tests/tmp/runs\" if TESTS_RUNNING else RUNS_DIR) / args.task\n name = name or args.name or f\"{args.mode}\"\n save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)\n\n return Path(save_dir)", "chunk_type": "function", "name": "get_save_dir", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 385, "end_line": 414, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Return the directory path for saving outputs, derived from arguments or default settings.\n\nArgs:\n args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task',\n 'mode', and 'save_dir'.\n name (str | None): Optional name for the output directory. 
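check_cfg above coerces or rejects values based on key groups (float, fraction, int and bool keys). The snippet below exercises the soft-conversion path (hard=False) on a small dictionary. It assumes the ultralytics package is installed so that check_cfg is importable from ultralytics.cfg, the module named in the chunk metadata; the dictionary values are illustrative.

```python
# Assumes ultralytics is installed; check_cfg is defined in ultralytics/cfg/__init__.py above.
from ultralytics.cfg import check_cfg

cfg = {
    "epochs": "50",   # int key given as str  -> converted to 50 with hard=False
    "lr0": 0.01,      # valid fraction key, left unchanged
    "save": "true",   # bool key given as str -> bool("true") is True
    "workers": 7.0,   # int key given as float -> converted to 7
}
check_cfg(cfg, hard=False)  # modifies cfg in place instead of raising
print(cfg)
```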
If not provided, it defaults to 'args.name'\n or the 'args.mode'.\n\nReturns:\n (Path): Directory path where outputs should be saved.\n\nExamples:\n >>> from types import SimpleNamespace\n >>> args = SimpleNamespace(project=\"my_project\", task=\"detect\", mode=\"train\", exist_ok=True)\n >>> save_dir = get_save_dir(args)\n >>> print(save_dir)\n my_project/detect/train", "parameters": [ "args: SimpleNamespace", "name: str" ], "return_type": "Path", "decorators": [], "complexity_score": 2, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_get_save_dir_546ed6fe" }, { "content": "def _handle_deprecation(custom: Dict) -> Dict:\n \"\"\"\n Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.\n\n Args:\n custom (dict): Configuration dictionary potentially containing deprecated keys.\n\n Returns:\n (dict): Updated configuration dictionary with deprecated keys replaced.\n\n Examples:\n >>> custom_config = {\"boxes\": True, \"hide_labels\": \"False\", \"line_thickness\": 2}\n >>> _handle_deprecation(custom_config)\n >>> print(custom_config)\n {'show_boxes': True, 'show_labels': True, 'line_width': 2}\n\n Notes:\n This function modifies the input dictionary in-place, replacing deprecated keys with their current\n equivalents. 
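get_save_dir above builds runs/{task}/{name} style paths and relies on increment_path to avoid clobbering an existing run directory. Below is a small, generic sketch of that incrementing behaviour (train, train2, train3, ...) written directly with pathlib; increment_dir is an illustrative stand-in, not the Ultralytics increment_path implementation.

```python
from pathlib import Path

def increment_dir(path: Path) -> Path:
    """Return path if unused, otherwise append 2, 3, ... until a free directory name is found."""
    if not path.exists():
        return path
    for n in range(2, 10_000):
        candidate = path.with_name(f"{path.name}{n}")
        if not candidate.exists():
            return candidate
    raise RuntimeError("no free directory name found")

save_dir = increment_dir(Path("runs") / "detect" / "train")
save_dir.mkdir(parents=True, exist_ok=True)
print(save_dir)  # runs/detect/train on the first call, runs/detect/train2 on the next, ...
```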
It also handles value conversions where necessary, such as inverting boolean values for\n 'hide_labels' and 'hide_conf'.\n \"\"\"\n deprecated_mappings = {\n \"boxes\": (\"show_boxes\", lambda v: v),\n \"hide_labels\": (\"show_labels\", lambda v: not bool(v)),\n \"hide_conf\": (\"show_conf\", lambda v: not bool(v)),\n \"line_thickness\": (\"line_width\", lambda v: v),\n }\n removed_keys = {\"label_smoothing\", \"save_hybrid\", \"crop_fraction\"}\n\n for old_key, (new_key, transform) in deprecated_mappings.items():\n if old_key not in custom:\n continue\n deprecation_warn(old_key, new_key)\n custom[new_key] = transform(custom.pop(old_key))\n\n for key in removed_keys:\n if key not in custom:\n continue\n deprecation_warn(key)\n custom.pop(key)\n\n return custom", "chunk_type": "function", "name": "_handle_deprecation", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 417, "end_line": 458, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.\n\nArgs:\n custom (dict): Configuration dictionary potentially containing deprecated keys.\n\nReturns:\n (dict): Updated configuration dictionary with deprecated keys replaced.\n\nExamples:\n >>> custom_config = {\"boxes\": True, \"hide_labels\": \"False\", \"line_thickness\": 2}\n >>> _handle_deprecation(custom_config)\n >>> print(custom_config)\n {'show_boxes': True, 'show_labels': True, 'line_width': 2}\n\nNotes:\n This function modifies the input dictionary in-place, replacing deprecated keys with their current\n equivalents. It also handles value conversions where necessary, such as inverting boolean values for\n 'hide_labels' and 'hide_conf'.", "parameters": [ "custom: Dict" ], "return_type": "Dict", "decorators": [], "complexity_score": 5, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function__handle_deprecation_5d0de858" }, { "content": "def check_dict_alignment(base: Dict, custom: Dict, e: Exception = None) -> None:\n \"\"\"\n Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error\n messages for mismatched keys.\n\n Args:\n base (dict): The base configuration dictionary containing valid keys.\n custom (dict): The custom configuration dictionary to be checked for alignment.\n e (Exception | None): Optional error instance passed by the calling function.\n\n Raises:\n SystemExit: If mismatched keys are found between the custom and base dictionaries.\n\n Examples:\n >>> base_cfg = {\"epochs\": 50, 
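_handle_deprecation above rewrites deprecated keys (boxes, hide_labels, hide_conf, line_thickness) to their current names, inverting the hide_* booleans along the way. The sketch below reuses the same mapping from the chunk but logs with print instead of deprecation_warn, so it can run standalone.

```python
DEPRECATED = {
    "boxes": ("show_boxes", lambda v: v),
    "hide_labels": ("show_labels", lambda v: not bool(v)),
    "hide_conf": ("show_conf", lambda v: not bool(v)),
    "line_thickness": ("line_width", lambda v: v),
}

def upgrade_keys(custom: dict) -> dict:
    """Replace deprecated keys in-place with their current equivalents."""
    for old, (new, transform) in DEPRECATED.items():
        if old in custom:
            print(f"'{old}' is deprecated, use '{new}' instead")  # stand-in for deprecation_warn
            custom[new] = transform(custom.pop(old))
    return custom

print(upgrade_keys({"boxes": True, "hide_labels": False, "line_thickness": 2}))
# {'show_boxes': True, 'show_labels': True, 'line_width': 2}
```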
\"lr0\": 0.01, \"batch_size\": 16}\n >>> custom_cfg = {\"epoch\": 100, \"lr\": 0.02, \"batch_size\": 32}\n >>> try:\n ... check_dict_alignment(base_cfg, custom_cfg)\n ... except SystemExit:\n ... print(\"Mismatched keys found\")\n\n Notes:\n - Suggests corrections for mismatched keys based on similarity to valid keys.\n - Automatically replaces deprecated keys in the custom configuration with updated equivalents.\n - Prints detailed error messages for each mismatched key to help users correct their configurations.\n \"\"\"\n custom = _handle_deprecation(custom)\n base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))\n if mismatched := [k for k in custom_keys if k not in base_keys]:\n from difflib import get_close_matches\n\n string = \"\"\n for x in mismatched:\n matches = get_close_matches(x, base_keys) # key list\n matches = [f\"{k}={base[k]}\" if base.get(k) is not None else k for k in matches]\n match_str = f\"Similar arguments are i.e. {matches}.\" if matches else \"\"\n string += f\"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\\n\"\n raise SyntaxError(string + CLI_HELP_MSG) from e", "chunk_type": "function", "name": "check_dict_alignment", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 461, "end_line": 498, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error\nmessages for mismatched keys.\n\nArgs:\n base (dict): The base configuration dictionary containing valid keys.\n custom (dict): The custom configuration dictionary to be checked for alignment.\n e (Exception | None): Optional error instance passed by the calling function.\n\nRaises:\n SystemExit: If mismatched keys are found between the custom and base dictionaries.\n\nExamples:\n >>> base_cfg = {\"epochs\": 50, \"lr0\": 0.01, \"batch_size\": 16}\n >>> custom_cfg = {\"epoch\": 100, \"lr\": 0.02, \"batch_size\": 32}\n >>> try:\n ... check_dict_alignment(base_cfg, custom_cfg)\n ... except SystemExit:\n ... 
print(\"Mismatched keys found\")\n\nNotes:\n - Suggests corrections for mismatched keys based on similarity to valid keys.\n - Automatically replaces deprecated keys in the custom configuration with updated equivalents.\n - Prints detailed error messages for each mismatched key to help users correct their configurations.", "parameters": [ "base: Dict", "custom: Dict", "e: Exception" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_check_dict_alignment_9890b4c4" }, { "content": "def merge_equals_args(args: List[str]) -> List[str]:\n \"\"\"\n Merge arguments around isolated '=' in a list of strings and join fragments with brackets.\n\n This function handles the following cases:\n 1. ['arg', '=', 'val'] becomes ['arg=val']\n 2. ['arg=', 'val'] becomes ['arg=val']\n 3. ['arg', '=val'] becomes ['arg=val']\n 4. 
Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]']\n\n Args:\n args (List[str]): A list of strings where each element represents an argument or fragment.\n\n Returns:\n (List[str]): A list of strings where the arguments around isolated '=' are merged and fragments with brackets are joined.\n\n Examples:\n >>> args = [\"arg1\", \"=\", \"value\", \"arg2=\", \"value2\", \"arg3\", \"=value3\", \"imgsz=[3,\", \"640,\", \"640]\"]\n >>> merge_equals_args(args)\n ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]']\n \"\"\"\n new_args = []\n current = \"\"\n depth = 0\n\n i = 0\n while i < len(args):\n arg = args[i]\n\n # Handle equals sign merging\n if arg == \"=\" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']\n new_args[-1] += f\"={args[i + 1]}\"\n i += 2\n continue\n elif arg.endswith(\"=\") and i < len(args) - 1 and \"=\" not in args[i + 1]: # merge ['arg=', 'val']\n new_args.append(f\"{arg}{args[i + 1]}\")\n i += 2\n continue\n elif arg.startswith(\"=\") and i > 0: # merge ['arg', '=val']\n new_args[-1] += arg\n i += 1\n continue\n\n # Handle bracket joining\n depth += arg.count(\"[\") - arg.count(\"]\")\n current += arg\n if depth == 0:\n new_args.append(current)\n current = \"\"\n\n i += 1\n\n # Append any remaining current string\n if current:\n new_args.append(current)\n\n return new_args", "chunk_type": "function", "name": "merge_equals_args", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 501, "end_line": 557, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Merge arguments around isolated '=' in a list of strings and join fragments with brackets.\n\nThis function handles the following cases:\n 1. ['arg', '=', 'val'] becomes ['arg=val']\n 2. ['arg=', 'val'] becomes ['arg=val']\n 3. ['arg', '=val'] becomes ['arg=val']\n 4. 
Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]']\n\nArgs:\n args (List[str]): A list of strings where each element represents an argument or fragment.\n\nReturns:\n (List[str]): A list of strings where the arguments around isolated '=' are merged and fragments with brackets are joined.\n\nExamples:\n >>> args = [\"arg1\", \"=\", \"value\", \"arg2=\", \"value2\", \"arg3\", \"=value3\", \"imgsz=[3,\", \"640,\", \"640]\"]\n >>> merge_equals_args(args)\n ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]']", "parameters": [ "args: List[str]" ], "return_type": "List[str]", "decorators": [], "complexity_score": 7, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_merge_equals_args_8b6c7c35" }, { "content": "def handle_yolo_hub(args: List[str]) -> None:\n \"\"\"\n Handle Ultralytics HUB command-line interface (CLI) commands for authentication.\n\n This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a\n script with arguments related to HUB authentication.\n\n Args:\n args (List[str]): A list of command line arguments. The first argument should be either 'login'\n or 'logout'. For 'login', an optional second argument can be the API key.\n\n Examples:\n $ yolo login YOUR_API_KEY\n\n Notes:\n - The function imports the 'hub' module from ultralytics to perform login and logout operations.\n - For the 'login' command, if no API key is provided, an empty string is passed to the login function.\n - The 'logout' command does not require any additional arguments.\n \"\"\"\n from ultralytics import hub\n\n if args[0] == \"login\":\n key = args[1] if len(args) > 1 else \"\"\n # Log in to Ultralytics HUB using the provided API key\n hub.login(key)\n elif args[0] == \"logout\":\n # Log out from Ultralytics HUB\n hub.logout()", "chunk_type": "function", "name": "handle_yolo_hub", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 560, "end_line": 587, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Handle Ultralytics HUB command-line interface (CLI) commands for authentication.\n\nThis function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a\nscript with arguments related to HUB authentication.\n\nArgs:\n args (List[str]): A list of command line arguments. The first argument should be either 'login'\n or 'logout'. 
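merge_equals_args above re-joins CLI tokens that the shell split around '=' and across bracketed lists. The example below simply calls it on a token list to show the normalized result; it assumes the ultralytics package is installed so the function can be imported from ultralytics.cfg as indicated by the chunk metadata.

```python
# Assumes ultralytics is installed; merge_equals_args is defined in ultralytics/cfg/__init__.py above.
from ultralytics.cfg import merge_equals_args

tokens = ["model", "=", "yolo11n.pt", "imgsz=", "640", "classes=[0,", "2,", "3]"]
print(merge_equals_args(tokens))
# ['model=yolo11n.pt', 'imgsz=640', 'classes=[0,2,3]']
```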
For 'login', an optional second argument can be the API key.\n\nExamples:\n $ yolo login YOUR_API_KEY\n\nNotes:\n - The function imports the 'hub' module from ultralytics to perform login and logout operations.\n - For the 'login' command, if no API key is provided, an empty string is passed to the login function.\n - The 'logout' command does not require any additional arguments.", "parameters": [ "args: List[str]" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_handle_yolo_hub_7fb15a41" }, { "content": "def handle_yolo_settings(args: List[str]) -> None:\n \"\"\"\n Handle YOLO settings command-line interface (CLI) commands.\n\n This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be\n called when executing a script with arguments related to YOLO settings management.\n\n Args:\n args (List[str]): A list of command line arguments for YOLO settings management.\n\n Examples:\n >>> handle_yolo_settings([\"reset\"]) # Reset YOLO settings\n >>> handle_yolo_settings([\"default_cfg_path=yolo11n.yaml\"]) # Update a specific setting\n\n Notes:\n - If no arguments are provided, the function will display the current settings.\n - The 'reset' command will delete the existing settings file and create new default settings.\n - Other arguments are treated as key-value pairs to update specific settings.\n - The function will check for alignment between the provided settings and the existing ones.\n - After processing, the updated settings will be displayed.\n - For more information on handling YOLO settings, visit:\n https://docs.ultralytics.com/quickstart/#ultralytics-settings\n \"\"\"\n url = \"https://docs.ultralytics.com/quickstart/#ultralytics-settings\" # help URL\n try:\n if any(args):\n if args[0] == \"reset\":\n SETTINGS_FILE.unlink() # delete the settings file\n SETTINGS.reset() # create new settings\n LOGGER.info(\"Settings reset successfully\") # inform the user that settings have been reset\n else: # save a new setting\n new = dict(parse_key_value_pair(a) for a in args)\n check_dict_alignment(SETTINGS, new)\n SETTINGS.update(new)\n for k, v in new.items():\n LOGGER.info(f\"✅ Updated '{k}={v}'\")\n\n LOGGER.info(SETTINGS) # print the current settings\n LOGGER.info(f\"💡 Learn more about Ultralytics Settings at {url}\")\n except Exception as e:\n LOGGER.warning(f\"settings error: '{e}'. 
Please see {url} for help.\")", "chunk_type": "function", "name": "handle_yolo_settings", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 590, "end_line": 630, "start_col": 0, "end_col": 76, "parent_name": null, "docstring": "Handle YOLO settings command-line interface (CLI) commands.\n\nThis function processes YOLO settings CLI commands such as reset and updating individual settings. It should be\ncalled when executing a script with arguments related to YOLO settings management.\n\nArgs:\n args (List[str]): A list of command line arguments for YOLO settings management.\n\nExamples:\n >>> handle_yolo_settings([\"reset\"]) # Reset YOLO settings\n >>> handle_yolo_settings([\"default_cfg_path=yolo11n.yaml\"]) # Update a specific setting\n\nNotes:\n - If no arguments are provided, the function will display the current settings.\n - The 'reset' command will delete the existing settings file and create new default settings.\n - Other arguments are treated as key-value pairs to update specific settings.\n - The function will check for alignment between the provided settings and the existing ones.\n - After processing, the updated settings will be displayed.\n - For more information on handling YOLO settings, visit:\n https://docs.ultralytics.com/quickstart/#ultralytics-settings", "parameters": [ "args: List[str]" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_handle_yolo_settings_67dd489b" }, { "content": "def handle_yolo_solutions(args: List[str]) -> None:\n \"\"\"\n Process YOLO solutions arguments and run the specified computer vision solutions pipeline.\n\n Args:\n args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO\n solutions: https://docs.ultralytics.com/solutions/, It can include solution name, source,\n and other configuration parameters.\n\n Examples:\n Run people counting solution with default settings:\n >>> handle_yolo_solutions([\"count\"])\n\n Run analytics with custom configuration:\n >>> handle_yolo_solutions([\"analytics\", \"conf=0.25\", \"source=path/to/video.mp4\"])\n\n Run inference with custom configuration, requires Streamlit version 1.29.0 or higher.\n >>> handle_yolo_solutions([\"inference\", \"model=yolo11n.pt\"])\n\n Notes:\n - Arguments can be provided in the format 'key=value' or as boolean flags\n - Available solutions are defined in SOLUTION_MAP with their respective classes and methods\n - If an invalid solution is provided, defaults to 'count' solution\n - Output videos 
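handle_yolo_settings above parses 'key=value' tokens, validates them against the current settings, and applies the update. The snippet below sketches only the parse-and-update step against a plain dictionary; the settings dictionary and parse_pair helper are illustrative stand-ins, not the real SETTINGS object or parse_key_value_pair.

```python
def parse_pair(pair: str):
    """Split 'key=value' on the first '=' and strip surrounding whitespace."""
    k, v = pair.split("=", 1)
    return k.strip(), v.strip()

settings = {"runs_dir": "runs", "sync": True}  # illustrative stand-in for SETTINGS
args = ["runs_dir=/data/runs", "sync=False"]

updates = dict(parse_pair(a) for a in args)
unknown = [k for k in updates if k not in settings]
if unknown:
    raise SystemExit(f"Unknown settings keys: {unknown}")
settings.update(updates)
print(settings)  # values stay strings in this sketch; smart_value handles type conversion upstream
```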
are saved in 'runs/solution/{solution_name}' directory\n - For 'analytics' solution, frame numbers are tracked for generating analytical graphs\n - Video processing can be interrupted by pressing 'q'\n - Processes video frames sequentially and saves output in .avi format\n - If no source is specified, downloads and uses a default sample video\n - The inference solution will be launched using the 'streamlit run' command.\n - The Streamlit app file is located in the Ultralytics package directory.\n \"\"\"\n from ultralytics.solutions.config import SolutionConfig\n\n full_args_dict = vars(SolutionConfig()) # arguments dictionary\n overrides = {}\n\n # check dictionary alignment\n for arg in merge_equals_args(args):\n arg = arg.lstrip(\"-\").rstrip(\",\")\n if \"=\" in arg:\n try:\n k, v = parse_key_value_pair(arg)\n overrides[k] = v\n except (NameError, SyntaxError, ValueError, AssertionError) as e:\n check_dict_alignment(full_args_dict, {arg: \"\"}, e)\n elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool):\n overrides[arg] = True\n check_dict_alignment(full_args_dict, overrides) # dict alignment\n\n # Get solution name\n if not args:\n LOGGER.warning(\"No solution name provided. i.e `yolo solutions count`. Defaulting to 'count'.\")\n args = [\"count\"]\n if args[0] == \"help\":\n LOGGER.info(SOLUTIONS_HELP_MSG)\n return # Early return for 'help' case\n elif args[0] in SOLUTION_MAP:\n solution_name = args.pop(0) # Extract the solution name directly\n else:\n LOGGER.warning(\n f\"❌ '{args[0]}' is not a valid solution. 💡 Defaulting to 'count'.\\n\"\n f\"🚀 Available solutions: {', '.join(list(SOLUTION_MAP.keys())[:-1])}\\n\"\n )\n solution_name = \"count\" # Default for invalid solution\n\n if solution_name == \"inference\":\n checks.check_requirements(\"streamlit>=1.29.0\")\n LOGGER.info(\"💡 Loading Ultralytics live inference app...\")\n subprocess.run(\n [ # Run subprocess with Streamlit custom argument\n \"streamlit\",\n \"run\",\n str(ROOT / \"solutions/streamlit_inference.py\"),\n \"--server.headless\",\n \"true\",\n overrides.pop(\"model\", \"yolo11n.pt\"),\n ]\n )\n else:\n import cv2 # Only needed for cap and vw functionality\n\n from ultralytics import solutions\n\n solution = getattr(solutions, SOLUTION_MAP[solution_name])(is_cli=True, **overrides) # class i.e ObjectCounter\n\n cap = cv2.VideoCapture(solution.CFG[\"source\"]) # read the video file\n if solution_name != \"crop\":\n # extract width, height and fps of the video file, create save directory and initialize video writer\n w, h, fps = (\n int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)\n )\n if solution_name == \"analytics\": # analytical graphs follow fixed shape for output i.e w=1920, h=1080\n w, h = 1280, 720\n save_dir = get_save_dir(SimpleNamespace(project=\"runs/solutions\", name=\"exp\", exist_ok=False))\n save_dir.mkdir(parents=True) # create the output directory i.e. 
runs/solutions/exp\n vw = cv2.VideoWriter(str(save_dir / f\"{solution_name}.avi\"), cv2.VideoWriter_fourcc(*\"mp4v\"), fps, (w, h))\n\n try: # Process video frames\n f_n = 0 # frame number, required for analytical graphs\n while cap.isOpened():\n success, frame = cap.read()\n if not success:\n break\n results = solution(frame, f_n := f_n + 1) if solution_name == \"analytics\" else solution(frame)\n if solution_name != \"crop\":\n vw.write(results.plot_im)\n if solution.CFG[\"show\"] and cv2.waitKey(1) & 0xFF == ord(\"q\"):\n break\n finally:\n cap.release()", "chunk_type": "function", "name": "handle_yolo_solutions", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 633, "end_line": 742, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Process YOLO solutions arguments and run the specified computer vision solutions pipeline.\n\nArgs:\n args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO\n solutions: https://docs.ultralytics.com/solutions/, It can include solution name, source,\n and other configuration parameters.\n\nExamples:\n Run people counting solution with default settings:\n >>> handle_yolo_solutions([\"count\"])\n\n Run analytics with custom configuration:\n >>> handle_yolo_solutions([\"analytics\", \"conf=0.25\", \"source=path/to/video.mp4\"])\n\n Run inference with custom configuration, requires Streamlit version 1.29.0 or higher.\n >>> handle_yolo_solutions([\"inference\", \"model=yolo11n.pt\"])\n\nNotes:\n - Arguments can be provided in the format 'key=value' or as boolean flags\n - Available solutions are defined in SOLUTION_MAP with their respective classes and methods\n - If an invalid solution is provided, defaults to 'count' solution\n - Output videos are saved in 'runs/solution/{solution_name}' directory\n - For 'analytics' solution, frame numbers are tracked for generating analytical graphs\n - Video processing can be interrupted by pressing 'q'\n - Processes video frames sequentially and saves output in .avi format\n - If no source is specified, downloads and uses a default sample video\n - The inference solution will be launched using the 'streamlit run' command.\n - The Streamlit app file is located in the Ultralytics package directory.", "parameters": [ "args: List[str]" ], "return_type": "None", "decorators": [], "complexity_score": 16, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_handle_yolo_solutions_60beb7f1" }, { "content": "def parse_key_value_pair(pair: str = \"key=value\") -> tuple:\n \"\"\"\n Parse a key-value pair 
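handle_yolo_solutions above drives a standard OpenCV read-process-write loop: open the source with VideoCapture, create a VideoWriter using the source's width, height and FPS, and stream frames through the selected solution until the video ends or 'q' is pressed. A stripped-down version of that loop is shown below; it copies the input video unchanged (no solution applied) and assumes a video exists at the hypothetical path "path/to/video.mp4".

```python
import cv2

source = "path/to/video.mp4"  # hypothetical input path
cap = cv2.VideoCapture(source)
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
vw = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

try:
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        # A real solution object would transform the frame here, e.g. frame = solution(frame).plot_im
        vw.write(frame)
finally:
    cap.release()
    vw.release()
```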
string into separate key and value components.\n\n Args:\n pair (str): A string containing a key-value pair in the format \"key=value\".\n\n Returns:\n key (str): The parsed key.\n value (str): The parsed value.\n\n Raises:\n AssertionError: If the value is missing or empty.\n\n Examples:\n >>> key, value = parse_key_value_pair(\"model=yolo11n.pt\")\n >>> print(f\"Key: {key}, Value: {value}\")\n Key: model, Value: yolo11n.pt\n\n >>> key, value = parse_key_value_pair(\"epochs=100\")\n >>> print(f\"Key: {key}, Value: {value}\")\n Key: epochs, Value: 100\n\n Notes:\n - The function splits the input string on the first '=' character.\n - Leading and trailing whitespace is removed from both key and value.\n - An assertion error is raised if the value is empty after stripping.\n \"\"\"\n k, v = pair.split(\"=\", 1) # split on first '=' sign\n k, v = k.strip(), v.strip() # remove spaces\n assert v, f\"missing '{k}' value\"\n return k, smart_value(v)", "chunk_type": "function", "name": "parse_key_value_pair", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 745, "end_line": 776, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Parse a key-value pair string into separate key and value components.\n\nArgs:\n pair (str): A string containing a key-value pair in the format \"key=value\".\n\nReturns:\n key (str): The parsed key.\n value (str): The parsed value.\n\nRaises:\n AssertionError: If the value is missing or empty.\n\nExamples:\n >>> key, value = parse_key_value_pair(\"model=yolo11n.pt\")\n >>> print(f\"Key: {key}, Value: {value}\")\n Key: model, Value: yolo11n.pt\n\n >>> key, value = parse_key_value_pair(\"epochs=100\")\n >>> print(f\"Key: {key}, Value: {value}\")\n Key: epochs, Value: 100\n\nNotes:\n - The function splits the input string on the first '=' character.\n - Leading and trailing whitespace is removed from both key and value.\n - An assertion error is raised if the value is empty after stripping.", "parameters": [ "pair: str" ], "return_type": "tuple", "decorators": [], "complexity_score": 1, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_parse_key_value_pair_f04e255d" }, { "content": "def smart_value(v: str) -> Any:\n \"\"\"\n Convert a string representation of a value to its appropriate Python type.\n\n This function attempts to convert a given string into a Python object of the most appropriate type. 
It handles\n conversions to None, bool, int, float, and other types that can be evaluated safely.\n\n Args:\n v (str): The string representation of the value to be converted.\n\n Returns:\n (Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion\n is applicable.\n\n Examples:\n >>> smart_value(\"42\")\n 42\n >>> smart_value(\"3.14\")\n 3.14\n >>> smart_value(\"True\")\n True\n >>> smart_value(\"None\")\n None\n >>> smart_value(\"some_string\")\n 'some_string'\n\n Notes:\n - The function uses a case-insensitive comparison for boolean and None values.\n - For other types, it attempts to use Python's eval() function, which can be unsafe if used on untrusted input.\n - If no conversion is possible, the original string is returned.\n \"\"\"\n v_lower = v.lower()\n if v_lower == \"none\":\n return None\n elif v_lower == \"true\":\n return True\n elif v_lower == \"false\":\n return False\n else:\n try:\n return eval(v)\n except Exception:\n return v", "chunk_type": "function", "name": "smart_value", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 779, "end_line": 821, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Convert a string representation of a value to its appropriate Python type.\n\nThis function attempts to convert a given string into a Python object of the most appropriate type. It handles\nconversions to None, bool, int, float, and other types that can be evaluated safely.\n\nArgs:\n v (str): The string representation of the value to be converted.\n\nReturns:\n (Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion\n is applicable.\n\nExamples:\n >>> smart_value(\"42\")\n 42\n >>> smart_value(\"3.14\")\n 3.14\n >>> smart_value(\"True\")\n True\n >>> smart_value(\"None\")\n None\n >>> smart_value(\"some_string\")\n 'some_string'\n\nNotes:\n - The function uses a case-insensitive comparison for boolean and None values.\n - For other types, it attempts to use Python's eval() function, which can be unsafe if used on untrusted input.\n - If no conversion is possible, the original string is returned.", "parameters": [ "v: str" ], "return_type": "Any", "decorators": [], "complexity_score": 5, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_smart_value_4f0fb7f9" }, { "content": "def entrypoint(debug: str = \"\") -> None:\n \"\"\"\n Ultralytics entrypoint function for parsing and executing command-line arguments.\n\n This function serves as the main entry point for the Ultralytics 
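A short sketch of how parse_key_value_pair and smart_value combine to build a typed overrides dict, assuming both are importable from ultralytics.cfg as the module-level functions shown above:

    from ultralytics.cfg import parse_key_value_pair

    overrides = {}
    for arg in ["epochs=100", "lr0=0.01", "rect=True", "device=None", "name=exp1"]:
        k, v = parse_key_value_pair(arg)  # smart_value() turns the string into int/float/bool/None
        overrides[k] = v
    print(overrides)  # {'epochs': 100, 'lr0': 0.01, 'rect': True, 'device': None, 'name': 'exp1'}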
CLI, parsing command-line arguments and\n executing the corresponding tasks such as training, validation, prediction, exporting models, and more.\n\n Args:\n debug (str): Space-separated string of command-line arguments for debugging purposes.\n\n Examples:\n Train a detection model for 10 epochs with an initial learning_rate of 0.01:\n >>> entrypoint(\"train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01\")\n\n Predict a YouTube video using a pretrained segmentation model at image size 320:\n >>> entrypoint(\"predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320\")\n\n Validate a pretrained detection model at batch-size 1 and image size 640:\n >>> entrypoint(\"val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640\")\n\n Notes:\n - If no arguments are passed, the function will display the usage help message.\n - For a list of all available commands and their arguments, see the provided help messages and the\n Ultralytics documentation at https://docs.ultralytics.com.\n \"\"\"\n args = (debug.split(\" \") if debug else ARGV)[1:]\n if not args: # no arguments passed\n LOGGER.info(CLI_HELP_MSG)\n return\n\n special = {\n \"help\": lambda: LOGGER.info(CLI_HELP_MSG),\n \"checks\": checks.collect_system_info,\n \"version\": lambda: LOGGER.info(__version__),\n \"settings\": lambda: handle_yolo_settings(args[1:]),\n \"cfg\": lambda: YAML.print(DEFAULT_CFG_PATH),\n \"hub\": lambda: handle_yolo_hub(args[1:]),\n \"login\": lambda: handle_yolo_hub(args),\n \"logout\": lambda: handle_yolo_hub(args),\n \"copy-cfg\": copy_default_cfg,\n \"solutions\": lambda: handle_yolo_solutions(args[1:]),\n }\n full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}\n\n # Define common misuses of special commands, i.e. -h, -help, --help\n special.update({k[0]: v for k, v in special.items()}) # singular\n special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith(\"s\")}) # singular\n special = {**special, **{f\"-{k}\": v for k, v in special.items()}, **{f\"--{k}\": v for k, v in special.items()}}\n\n overrides = {} # basic overrides, i.e. imgsz=320\n for a in merge_equals_args(args): # merge spaces around '=' sign\n if a.startswith(\"--\"):\n LOGGER.warning(f\"argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.\")\n a = a[2:]\n if a.endswith(\",\"):\n LOGGER.warning(f\"argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.\")\n a = a[:-1]\n if \"=\" in a:\n try:\n k, v = parse_key_value_pair(a)\n if k == \"cfg\" and v is not None: # custom.yaml passed\n LOGGER.info(f\"Overriding {DEFAULT_CFG_PATH} with {v}\")\n overrides = {k: val for k, val in YAML.load(checks.check_yaml(v)).items() if k != \"cfg\"}\n else:\n overrides[k] = v\n except (NameError, SyntaxError, ValueError, AssertionError) as e:\n check_dict_alignment(full_args_dict, {a: \"\"}, e)\n\n elif a in TASKS:\n overrides[\"task\"] = a\n elif a in MODES:\n overrides[\"mode\"] = a\n elif a.lower() in special:\n special[a.lower()]()\n return\n elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):\n overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True\n elif a in DEFAULT_CFG_DICT:\n raise SyntaxError(\n f\"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign \"\n f\"to set its value, i.e. 
try '{a}={DEFAULT_CFG_DICT[a]}'\\n{CLI_HELP_MSG}\"\n )\n else:\n check_dict_alignment(full_args_dict, {a: \"\"})\n\n # Check keys\n check_dict_alignment(full_args_dict, overrides)\n\n # Mode\n mode = overrides.get(\"mode\")\n if mode is None:\n mode = DEFAULT_CFG.mode or \"predict\"\n LOGGER.warning(f\"'mode' argument is missing. Valid modes are {list(MODES)}. Using default 'mode={mode}'.\")\n elif mode not in MODES:\n raise ValueError(f\"Invalid 'mode={mode}'. Valid modes are {list(MODES)}.\\n{CLI_HELP_MSG}\")\n\n # Task\n task = overrides.pop(\"task\", None)\n if task:\n if task not in TASKS:\n if task == \"track\":\n LOGGER.warning(\n f\"invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {list(TASKS)}.\\n{CLI_HELP_MSG}.\"\n )\n task, mode = \"detect\", \"track\"\n else:\n raise ValueError(f\"Invalid 'task={task}'. Valid tasks are {list(TASKS)}.\\n{CLI_HELP_MSG}\")\n if \"model\" not in overrides:\n overrides[\"model\"] = TASK2MODEL[task]\n\n # Model\n model = overrides.pop(\"model\", DEFAULT_CFG.model)\n if model is None:\n model = \"yolo11n.pt\"\n LOGGER.warning(f\"'model' argument is missing. Using default 'model={model}'.\")\n overrides[\"model\"] = model\n stem = Path(model).stem.lower()\n if \"rtdetr\" in stem: # guess architecture\n from ultralytics import RTDETR\n\n model = RTDETR(model) # no task argument\n elif \"fastsam\" in stem:\n from ultralytics import FastSAM\n\n model = FastSAM(model)\n elif \"sam_\" in stem or \"sam2_\" in stem or \"sam2.1_\" in stem:\n from ultralytics import SAM\n\n model = SAM(model)\n else:\n from ultralytics import YOLO\n\n model = YOLO(model, task=task)\n\n # Task Update\n if task != model.task:\n if task:\n LOGGER.warning(\n f\"conflicting 'task={task}' passed with 'task={model.task}' model. \"\n f\"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.\"\n )\n task = model.task\n\n # Mode\n if mode in {\"predict\", \"track\"} and \"source\" not in overrides:\n overrides[\"source\"] = (\n \"https://ultralytics.com/images/boats.jpg\" if task == \"obb\" else DEFAULT_CFG.source or ASSETS\n )\n LOGGER.warning(f\"'source' argument is missing. Using default 'source={overrides['source']}'.\")\n elif mode in {\"train\", \"val\"}:\n if \"data\" not in overrides and \"resume\" not in overrides:\n overrides[\"data\"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)\n LOGGER.warning(f\"'data' argument is missing. Using default 'data={overrides['data']}'.\")\n elif mode == \"export\":\n if \"format\" not in overrides:\n overrides[\"format\"] = DEFAULT_CFG.format or \"torchscript\"\n LOGGER.warning(f\"'format' argument is missing. 
Using default 'format={overrides['format']}'.\")\n\n # Run command in python\n getattr(model, mode)(**overrides) # default args from model\n\n # Show help\n LOGGER.info(f\"💡 Learn more at https://docs.ultralytics.com/modes/{mode}\")\n\n # Recommend VS Code extension\n if IS_VSCODE and SETTINGS.get(\"vscode_msg\", True):\n LOGGER.info(vscode_msg())", "chunk_type": "function", "name": "entrypoint", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 824, "end_line": 990, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Ultralytics entrypoint function for parsing and executing command-line arguments.\n\nThis function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments and\nexecuting the corresponding tasks such as training, validation, prediction, exporting models, and more.\n\nArgs:\n debug (str): Space-separated string of command-line arguments for debugging purposes.\n\nExamples:\n Train a detection model for 10 epochs with an initial learning_rate of 0.01:\n >>> entrypoint(\"train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01\")\n\n Predict a YouTube video using a pretrained segmentation model at image size 320:\n >>> entrypoint(\"predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320\")\n\n Validate a pretrained detection model at batch-size 1 and image size 640:\n >>> entrypoint(\"val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640\")\n\nNotes:\n - If no arguments are passed, the function will display the usage help message.\n - For a list of all available commands and their arguments, see the provided help messages and the\n Ultralytics documentation at https://docs.ultralytics.com.", "parameters": [ "debug: str" ], "return_type": "None", "decorators": [], "complexity_score": 38, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_entrypoint_c9aa0bc3" }, { "content": "def copy_default_cfg() -> None:\n \"\"\"\n Copy the default configuration file and create a new one with '_copy' appended to its name.\n\n This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it\n with '_copy' appended to its name in the current working directory. 
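The model-class selection near the end of entrypoint is driven purely by the checkpoint filename stem. The sketch below restates that branch logic outside the CLI; the example filenames are illustrative assumptions:

    from pathlib import Path

    def guess_model_class(weights: str) -> str:
        # Mirrors the stem checks used by entrypoint: RT-DETR, FastSAM and SAM
        # checkpoints are recognised by name, everything else falls back to YOLO.
        stem = Path(weights).stem.lower()
        if "rtdetr" in stem:
            return "RTDETR"
        if "fastsam" in stem:
            return "FastSAM"
        if "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem:
            return "SAM"
        return "YOLO"

    print(guess_model_class("rtdetr-l.pt"))     # RTDETR
    print(guess_model_class("sam2.1_b.pt"))     # SAM
    print(guess_model_class("yolo11n-seg.pt"))  # YOLO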
It provides a convenient way\n to create a custom configuration file based on the default settings.\n\n Examples:\n >>> copy_default_cfg()\n # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml\n # Example YOLO command with this new custom cfg:\n # yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8\n\n Notes:\n - The new configuration file is created in the current working directory.\n - After copying, the function prints a message with the new file's location and an example\n YOLO command demonstrating how to use the new configuration file.\n - This function is useful for users who want to modify the default configuration without\n altering the original file.\n \"\"\"\n new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(\".yaml\", \"_copy.yaml\")\n shutil.copy2(DEFAULT_CFG_PATH, new_file)\n LOGGER.info(\n f\"{DEFAULT_CFG_PATH} copied to {new_file}\\n\"\n f\"Example YOLO command with this new custom cfg:\\n yolo cfg='{new_file}' imgsz=320 batch=8\"\n )", "chunk_type": "function", "name": "copy_default_cfg", "file_path": "ultralytics\\ultralytics\\cfg\\__init__.py", "start_line": 994, "end_line": 1020, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Copy the default configuration file and create a new one with '_copy' appended to its name.\n\nThis function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it\nwith '_copy' appended to its name in the current working directory. It provides a convenient way\nto create a custom configuration file based on the default settings.\n\nExamples:\n >>> copy_default_cfg()\n # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml\n # Example YOLO command with this new custom cfg:\n # yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8\n\nNotes:\n - The new configuration file is created in the current working directory.\n - After copying, the function prints a message with the new file's location and an example\n YOLO command demonstrating how to use the new configuration file.\n - This function is useful for users who want to modify the default configuration without\n altering the original file.", "parameters": [], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "shutil", "subprocess", "sys", "pathlib.Path", "types.SimpleNamespace", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "ultralytics.__version__", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_PATH", "ultralytics.utils.IS_VSCODE", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.ROOT", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ultralytics.utils.deprecation_warn", "ultralytics.utils.vscode_msg", "ultralytics.hub", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.files.increment_path", "difflib.get_close_matches", "cv2", "ultralytics.solutions", "ultralytics.RTDETR", "ultralytics.FastSAM", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_copy_default_cfg_fe328855" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\annotator.py", "start_line": 3, "end_line": 3, "start_col": 0, 
"end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_a2a6c2c0" }, { "content": "from typing import List, Optional, Union", "chunk_type": "import", "name": "List, Optional, Union", "file_path": "ultralytics\\ultralytics\\data\\annotator.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Union_1e3be1f9" }, { "content": "from ultralytics import SAM, YOLO", "chunk_type": "import", "name": "SAM, YOLO", "file_path": "ultralytics\\ultralytics\\data\\annotator.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SAM, YOLO_fc4ca142" }, { "content": "def auto_annotate(\n data: Union[str, Path],\n det_model: str = \"yolo11x.pt\",\n sam_model: str = \"sam_b.pt\",\n device: str = \"\",\n conf: float = 0.25,\n iou: float = 0.45,\n imgsz: int = 640,\n max_det: int = 300,\n classes: Optional[List[int]] = None,\n output_dir: Optional[Union[str, Path]] = None,\n) -> None:\n \"\"\"\n Automatically annotate images using a YOLO object detection model and a SAM segmentation model.\n\n This function processes images in a specified directory, detects objects using a YOLO model, and then generates\n segmentation masks using a SAM model. The resulting annotations are saved as text files in YOLO format.\n\n Args:\n data (str | Path): Path to a folder containing images to be annotated.\n det_model (str): Path or name of the pre-trained YOLO detection model.\n sam_model (str): Path or name of the pre-trained SAM segmentation model.\n device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0'). Empty string for auto-selection.\n conf (float): Confidence threshold for detection model.\n iou (float): IoU threshold for filtering overlapping boxes in detection results.\n imgsz (int): Input image resize dimension.\n max_det (int): Maximum number of detections per image.\n classes (List[int], optional): Filter predictions to specified class IDs, returning only relevant detections.\n output_dir (str | Path, optional): Directory to save the annotated results. 
If None, creates a default\n directory based on the input data path.\n\n Examples:\n >>> from ultralytics.data.annotator import auto_annotate\n >>> auto_annotate(data=\"ultralytics/assets\", det_model=\"yolo11n.pt\", sam_model=\"mobile_sam.pt\")\n \"\"\"\n det_model = YOLO(det_model)\n sam_model = SAM(sam_model)\n\n data = Path(data)\n if not output_dir:\n output_dir = data.parent / f\"{data.stem}_auto_annotate_labels\"\n Path(output_dir).mkdir(exist_ok=True, parents=True)\n\n det_results = det_model(\n data, stream=True, device=device, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det, classes=classes\n )\n\n for result in det_results:\n class_ids = result.boxes.cls.int().tolist() # Extract class IDs from detection results\n if class_ids:\n boxes = result.boxes.xyxy # Boxes object for bbox outputs\n sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)\n segments = sam_results[0].masks.xyn\n\n with open(f\"{Path(output_dir) / Path(result.path).stem}.txt\", \"w\", encoding=\"utf-8\") as f:\n for i, s in enumerate(segments):\n if s.any():\n segment = map(str, s.reshape(-1).tolist())\n f.write(f\"{class_ids[i]} \" + \" \".join(segment) + \"\\n\")", "chunk_type": "function", "name": "auto_annotate", "file_path": "ultralytics\\ultralytics\\data\\annotator.py", "start_line": 9, "end_line": 67, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": "Automatically annotate images using a YOLO object detection model and a SAM segmentation model.\n\nThis function processes images in a specified directory, detects objects using a YOLO model, and then generates\nsegmentation masks using a SAM model. The resulting annotations are saved as text files in YOLO format.\n\nArgs:\n data (str | Path): Path to a folder containing images to be annotated.\n det_model (str): Path or name of the pre-trained YOLO detection model.\n sam_model (str): Path or name of the pre-trained SAM segmentation model.\n device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0'). Empty string for auto-selection.\n conf (float): Confidence threshold for detection model.\n iou (float): IoU threshold for filtering overlapping boxes in detection results.\n imgsz (int): Input image resize dimension.\n max_det (int): Maximum number of detections per image.\n classes (List[int], optional): Filter predictions to specified class IDs, returning only relevant detections.\n output_dir (str | Path, optional): Directory to save the annotated results. 
If None, creates a default\n directory based on the input data path.\n\nExamples:\n >>> from ultralytics.data.annotator import auto_annotate\n >>> auto_annotate(data=\"ultralytics/assets\", det_model=\"yolo11n.pt\", sam_model=\"mobile_sam.pt\")", "parameters": [ "data: Union[str, Path]", "det_model: str", "sam_model: str", "device: str", "conf: float", "iou: float", "imgsz: int", "max_det: int", "classes: Optional[List[int]]", "output_dir: Optional[Union[str, Path]]" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "ultralytics.SAM", "ultralytics.YOLO" ], "chunk_id": "function_auto_annotate_9805e429" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_35a0ced2" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_9c265621" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_036cf453" }, { "content": "from typing import Any, Dict, List, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Tuple, Union", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple, Union_9ff08eb2" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_bccee3cd" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_ae1e4627" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_f9dddff1" }, { "content": "from PIL import Image", "chunk_type": "import", "name": 
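Each label file written by auto_annotate holds one line per detected object: a class id followed by the flattened, normalized xy polygon. A small sketch of reading one back; the file path below assumes the default output directory naming and a 'bus.jpg' input image:

    from pathlib import Path

    label_file = Path("ultralytics/assets_auto_annotate_labels/bus.txt")  # assumed output path
    if label_file.exists():
        for line in label_file.read_text().splitlines():
            values = line.split()
            cls_id, coords = int(values[0]), [float(v) for v in values[1:]]
            points = list(zip(coords[0::2], coords[1::2]))  # [(x1, y1), (x2, y2), ...], normalized 0-1
            print(cls_id, len(points), "polygon points")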
"Image", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_61fccbb2" }, { "content": "from torch.nn import functional as F", "chunk_type": "import", "name": "functional", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_functional_a720e409" }, { "content": "from ultralytics.data.utils import polygons2masks, polygons2masks_overlap", "chunk_type": "import", "name": "polygons2masks, polygons2masks_overlap", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_polygons2masks, polygons2masks_overlap_74dff7da" }, { "content": "from ultralytics.utils import LOGGER, IterableSimpleNamespace, colorstr", "chunk_type": "import", "name": "LOGGER, IterableSimpleNamespace, colorstr", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, IterableSimpleNamespace, colorstr_b65e386b" }, { "content": "from ultralytics.utils.checks import check_version", "chunk_type": "import", "name": "check_version", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_version_cdf3ca29" }, { "content": "from ultralytics.utils.instance import Instances", "chunk_type": "import", "name": "Instances", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Instances_ca1adc4c" }, { "content": "from ultralytics.utils.metrics import bbox_ioa", "chunk_type": "import", "name": "bbox_ioa", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bbox_ioa_5fe6fc0a" }, { "content": "from ultralytics.utils.ops import segment2box, xywh2xyxy, xyxyxyxy2xywhr", "chunk_type": "import", "name": "segment2box, xywh2xyxy, xyxyxyxy2xywhr", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_segment2box, xywh2xyxy, xyxyxyxy2xywhr_32030909" }, { 
"content": "from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13", "chunk_type": "import", "name": "TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13_2967d938" }, { "content": "DEFAULT_MEAN = (0.0, 0.0, 0.0)", "chunk_type": "variable", "name": "DEFAULT_MEAN", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 22, "end_line": 22, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEFAULT_MEAN_ccf4028e" }, { "content": "DEFAULT_STD = (1.0, 1.0, 1.0)", "chunk_type": "variable", "name": "DEFAULT_STD", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEFAULT_STD_6e070859" }, { "content": "class BaseTransform:\n \"\"\"\n Base class for image transformations in the Ultralytics library.\n\n This class serves as a foundation for implementing various image processing operations, designed to be\n compatible with both classification and semantic segmentation tasks.\n\n Methods:\n apply_image: Apply image transformations to labels.\n apply_instances: Apply transformations to object instances in labels.\n apply_semantic: Apply semantic segmentation to an image.\n __call__: Apply all label transformations to an image, instances, and semantic masks.\n\n Examples:\n >>> transform = BaseTransform()\n >>> labels = {\"image\": np.array(...), \"instances\": [...], \"semantic\": np.array(...)}\n >>> transformed_labels = transform(labels)\n \"\"\"\n\n def __init__(self) -> None:\n \"\"\"\n Initialize the BaseTransform object.\n\n This constructor sets up the base transformation object, which can be extended for specific image\n processing tasks. It is designed to be compatible with both classification and semantic segmentation.\n\n Examples:\n >>> transform = BaseTransform()\n \"\"\"\n pass\n\n def apply_image(self, labels):\n \"\"\"\n Apply image transformations to labels.\n\n This method is intended to be overridden by subclasses to implement specific image transformation\n logic. In its base form, it returns the input labels unchanged.\n\n Args:\n labels (Any): The input labels to be transformed. The exact type and structure of labels may\n vary depending on the specific implementation.\n\n Returns:\n (Any): The transformed labels. In the base implementation, this is identical to the input.\n\n Examples:\n >>> transform = BaseTransform()\n >>> original_labels = [1, 2, 3]\n >>> transformed_labels = transform.apply_image(original_labels)\n >>> print(transformed_labels)\n [1, 2, 3]\n \"\"\"\n pass\n\n def apply_instances(self, labels):\n \"\"\"\n Apply transformations to object instances in labels.\n\n This method is responsible for applying various transformations to object instances within the given\n labels. 
It is designed to be overridden by subclasses to implement specific instance transformation\n logic.\n\n Args:\n labels (dict): A dictionary containing label information, including object instances.\n\n Returns:\n (dict): The modified labels dictionary with transformed object instances.\n\n Examples:\n >>> transform = BaseTransform()\n >>> labels = {\"instances\": Instances(xyxy=torch.rand(5, 4), cls=torch.randint(0, 80, (5,)))}\n >>> transformed_labels = transform.apply_instances(labels)\n \"\"\"\n pass\n\n def apply_semantic(self, labels):\n \"\"\"\n Apply semantic segmentation transformations to an image.\n\n This method is intended to be overridden by subclasses to implement specific semantic segmentation\n transformations. In its base form, it does not perform any operations.\n\n Args:\n labels (Any): The input labels or semantic segmentation mask to be transformed.\n\n Returns:\n (Any): The transformed semantic segmentation mask or labels.\n\n Examples:\n >>> transform = BaseTransform()\n >>> semantic_mask = np.zeros((100, 100), dtype=np.uint8)\n >>> transformed_mask = transform.apply_semantic(semantic_mask)\n \"\"\"\n pass\n\n def __call__(self, labels):\n \"\"\"\n Apply all label transformations to an image, instances, and semantic masks.\n\n This method orchestrates the application of various transformations defined in the BaseTransform class\n to the input labels. It sequentially calls the apply_image and apply_instances methods to process the\n image and object instances, respectively.\n\n Args:\n labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for\n the image data, and 'instances' for object instances.\n\n Returns:\n (dict): The input labels dictionary with transformed image and instances.\n\n Examples:\n >>> transform = BaseTransform()\n >>> labels = {\"img\": np.random.rand(640, 640, 3), \"instances\": []}\n >>> transformed_labels = transform(labels)\n \"\"\"\n self.apply_image(labels)\n self.apply_instances(labels)\n self.apply_semantic(labels)", "chunk_type": "class", "name": "BaseTransform", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 26, "end_line": 143, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Base class for image transformations in the Ultralytics library.\n\nThis class serves as a foundation for implementing various image processing operations, designed to be\ncompatible with both classification and semantic segmentation tasks.\n\nMethods:\n apply_image: Apply image transformations to labels.\n apply_instances: Apply transformations to object instances in labels.\n apply_semantic: Apply semantic segmentation to an image.\n __call__: Apply all label transformations to an image, instances, and semantic masks.\n\nExamples:\n >>> transform = BaseTransform()\n >>> labels = {\"image\": np.array(...), \"instances\": [...], \"semantic\": np.array(...)}\n >>> transformed_labels = transform(labels)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", 
"ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_BaseTransform_60d9bfb5" }, { "content": "class Compose:\n \"\"\"\n A class for composing multiple image transformations.\n\n Attributes:\n transforms (List[Callable]): A list of transformation functions to be applied sequentially.\n\n Methods:\n __call__: Apply a series of transformations to input data.\n append: Append a new transform to the existing list of transforms.\n insert: Insert a new transform at a specified index in the list of transforms.\n __getitem__: Retrieve a specific transform or a set of transforms using indexing.\n __setitem__: Set a specific transform or a set of transforms using indexing.\n tolist: Convert the list of transforms to a standard Python list.\n\n Examples:\n >>> transforms = [RandomFlip(), RandomPerspective(30)]\n >>> compose = Compose(transforms)\n >>> transformed_data = compose(data)\n >>> compose.append(CenterCrop((224, 224)))\n >>> compose.insert(0, RandomFlip())\n \"\"\"\n\n def __init__(self, transforms):\n \"\"\"\n Initialize the Compose object with a list of transforms.\n\n Args:\n transforms (List[Callable]): A list of callable transform objects to be applied sequentially.\n\n Examples:\n >>> from ultralytics.data.augment import Compose, RandomHSV, RandomFlip\n >>> transforms = [RandomHSV(), RandomFlip()]\n >>> compose = Compose(transforms)\n \"\"\"\n self.transforms = transforms if isinstance(transforms, list) else [transforms]\n\n def __call__(self, data):\n \"\"\"\n Apply a series of transformations to input data.\n\n This method sequentially applies each transformation in the Compose object's transforms to the input data.\n\n Args:\n data (Any): The input data to be transformed. 
This can be of any type, depending on the\n transformations in the list.\n\n Returns:\n (Any): The transformed data after applying all transformations in sequence.\n\n Examples:\n >>> transforms = [Transform1(), Transform2(), Transform3()]\n >>> compose = Compose(transforms)\n >>> transformed_data = compose(input_data)\n \"\"\"\n for t in self.transforms:\n data = t(data)\n return data\n\n def append(self, transform):\n \"\"\"\n Append a new transform to the existing list of transforms.\n\n Args:\n transform (BaseTransform): The transformation to be added to the composition.\n\n Examples:\n >>> compose = Compose([RandomFlip(), RandomPerspective()])\n >>> compose.append(RandomHSV())\n \"\"\"\n self.transforms.append(transform)\n\n def insert(self, index, transform):\n \"\"\"\n Insert a new transform at a specified index in the existing list of transforms.\n\n Args:\n index (int): The index at which to insert the new transform.\n transform (BaseTransform): The transform object to be inserted.\n\n Examples:\n >>> compose = Compose([Transform1(), Transform2()])\n >>> compose.insert(1, Transform3())\n >>> len(compose.transforms)\n 3\n \"\"\"\n self.transforms.insert(index, transform)\n\n def __getitem__(self, index: Union[list, int]) -> \"Compose\":\n \"\"\"\n Retrieve a specific transform or a set of transforms using indexing.\n\n Args:\n index (int | List[int]): Index or list of indices of the transforms to retrieve.\n\n Returns:\n (Compose): A new Compose object containing the selected transform(s).\n\n Raises:\n AssertionError: If the index is not of type int or list.\n\n Examples:\n >>> transforms = [RandomFlip(), RandomPerspective(10), RandomHSV(0.5, 0.5, 0.5)]\n >>> compose = Compose(transforms)\n >>> single_transform = compose[1] # Returns a Compose object with only RandomPerspective\n >>> multiple_transforms = compose[0:2] # Returns a Compose object with RandomFlip and RandomPerspective\n \"\"\"\n assert isinstance(index, (int, list)), f\"The indices should be either list or int type but got {type(index)}\"\n return Compose([self.transforms[i] for i in index]) if isinstance(index, list) else self.transforms[index]\n\n def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:\n \"\"\"\n Set one or more transforms in the composition using indexing.\n\n Args:\n index (int | List[int]): Index or list of indices to set transforms at.\n value (Any | List[Any]): Transform or list of transforms to set at the specified index(es).\n\n Raises:\n AssertionError: If index type is invalid, value type doesn't match index type, or index is out of range.\n\n Examples:\n >>> compose = Compose([Transform1(), Transform2(), Transform3()])\n >>> compose[1] = NewTransform() # Replace second transform\n >>> compose[0:2] = [NewTransform1(), NewTransform2()] # Replace first two transforms\n \"\"\"\n assert isinstance(index, (int, list)), f\"The indices should be either list or int type but got {type(index)}\"\n if isinstance(index, list):\n assert isinstance(value, list), (\n f\"The indices should be the same type as values, but got {type(index)} and {type(value)}\"\n )\n if isinstance(index, int):\n index, value = [index], [value]\n for i, v in zip(index, value):\n assert i < len(self.transforms), f\"list index {i} out of range {len(self.transforms)}.\"\n self.transforms[i] = v\n\n def tolist(self):\n \"\"\"\n Convert the list of transforms to a standard Python list.\n\n Returns:\n (list): A list containing all the transform objects in the Compose instance.\n\n Examples:\n >>> transforms 
= [RandomFlip(), RandomPerspective(10), CenterCrop()]\n >>> compose = Compose(transforms)\n >>> transform_list = compose.tolist()\n >>> print(len(transform_list))\n 3\n \"\"\"\n return self.transforms\n\n def __repr__(self):\n \"\"\"\n Return a string representation of the Compose object.\n\n Returns:\n (str): A string representation of the Compose object, including the list of transforms.\n\n Examples:\n >>> transforms = [RandomFlip(), RandomPerspective(degrees=10, translate=0.1, scale=0.1)]\n >>> compose = Compose(transforms)\n >>> print(compose)\n Compose([\n RandomFlip(),\n RandomPerspective(degrees=10, translate=0.1, scale=0.1)\n ])\n \"\"\"\n return f\"{self.__class__.__name__}({', '.join([f'{t}' for t in self.transforms])})\"", "chunk_type": "class", "name": "Compose", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 146, "end_line": 315, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": "A class for composing multiple image transformations.\n\nAttributes:\n transforms (List[Callable]): A list of transformation functions to be applied sequentially.\n\nMethods:\n __call__: Apply a series of transformations to input data.\n append: Append a new transform to the existing list of transforms.\n insert: Insert a new transform at a specified index in the list of transforms.\n __getitem__: Retrieve a specific transform or a set of transforms using indexing.\n __setitem__: Set a specific transform or a set of transforms using indexing.\n tolist: Convert the list of transforms to a standard Python list.\n\nExamples:\n >>> transforms = [RandomFlip(), RandomPerspective(30)]\n >>> compose = Compose(transforms)\n >>> transformed_data = compose(data)\n >>> compose.append(CenterCrop((224, 224)))\n >>> compose.insert(0, RandomFlip())", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_Compose_48cd0349" }, { "content": "class BaseMixTransform:\n \"\"\"\n Base class for mix transformations like Cutmix, MixUp and Mosaic.\n\n This class provides a foundation for implementing mix transformations on datasets. 
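Because Compose only requires callables that accept and return the data object, plain functions are enough to sketch the chaining, indexing and append behaviour described above:

    from ultralytics.data.augment import Compose

    def add_one(data):
        data["value"] += 1
        return data

    def double(data):
        data["value"] *= 2
        return data

    compose = Compose([add_one, double])
    print(compose({"value": 3}))     # {'value': 8} -> (3 + 1) * 2
    print(compose[0]({"value": 3}))  # {'value': 4} -> integer indexing returns the single transform
    compose.append(add_one)
    print(len(compose.tolist()))     # 3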
It handles the\n probability-based application of transforms and manages the mixing of multiple images and labels.\n\n Attributes:\n dataset (Any): The dataset object containing images and labels.\n pre_transform (Callable | None): Optional transform to apply before mixing.\n p (float): Probability of applying the mix transformation.\n\n Methods:\n __call__: Apply the mix transformation to the input labels.\n _mix_transform: Abstract method to be implemented by subclasses for specific mix operations.\n get_indexes: Abstract method to get indexes of images to be mixed.\n _update_label_text: Update label text for mixed images.\n\n Examples:\n >>> class CustomMixTransform(BaseMixTransform):\n ... def _mix_transform(self, labels):\n ... # Implement custom mix logic here\n ... return labels\n ...\n ... def get_indexes(self):\n ... return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]\n >>> dataset = YourDataset()\n >>> transform = CustomMixTransform(dataset, p=0.5)\n >>> mixed_labels = transform(original_labels)\n \"\"\"\n\n def __init__(self, dataset, pre_transform=None, p=0.0) -> None:\n \"\"\"\n Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.\n\n This class serves as a base for implementing mix transformations in image processing pipelines.\n\n Args:\n dataset (Any): The dataset object containing images and labels for mixing.\n pre_transform (Callable | None): Optional transform to apply before mixing.\n p (float): Probability of applying the mix transformation. Should be in the range [0.0, 1.0].\n\n Examples:\n >>> dataset = YOLODataset(\"path/to/data\")\n >>> pre_transform = Compose([RandomFlip(), RandomPerspective()])\n >>> mix_transform = BaseMixTransform(dataset, pre_transform, p=0.5)\n \"\"\"\n self.dataset = dataset\n self.pre_transform = pre_transform\n self.p = p\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.\n\n This method determines whether to apply the mix transform based on a probability factor. If applied, it\n selects additional images, applies pre-transforms if specified, and then performs the mix transform.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing label data for an image.\n\n Returns:\n (Dict[str, Any]): The transformed labels dictionary, which may include mixed data from other images.\n\n Examples:\n >>> transform = BaseMixTransform(dataset, pre_transform=None, p=0.5)\n >>> result = transform({\"image\": img, \"bboxes\": boxes, \"cls\": classes})\n \"\"\"\n if random.uniform(0, 1) > self.p:\n return labels\n\n # Get index of one or three other images\n indexes = self.get_indexes()\n if isinstance(indexes, int):\n indexes = [indexes]\n\n # Get images information will be used for Mosaic, CutMix or MixUp\n mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]\n\n if self.pre_transform is not None:\n for i, data in enumerate(mix_labels):\n mix_labels[i] = self.pre_transform(data)\n labels[\"mix_labels\"] = mix_labels\n\n # Update cls and texts\n labels = self._update_label_text(labels)\n # Mosaic, CutMix or MixUp\n labels = self._mix_transform(labels)\n labels.pop(\"mix_labels\", None)\n return labels\n\n def _mix_transform(self, labels: Dict[str, Any]):\n \"\"\"\n Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.\n\n This method should be implemented by subclasses to perform specific mix transformations like CutMix, MixUp or\n Mosaic. 
It modifies the input label dictionary in-place with the augmented data.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image and label data. Expected to have a 'mix_labels' key\n with a list of additional image and label data for mixing.\n\n Returns:\n (Dict[str, Any]): The modified labels dictionary with augmented data after applying the mix transform.\n\n Examples:\n >>> transform = BaseMixTransform(dataset)\n >>> labels = {\"image\": img, \"bboxes\": boxes, \"mix_labels\": [{\"image\": img2, \"bboxes\": boxes2}]}\n >>> augmented_labels = transform._mix_transform(labels)\n \"\"\"\n raise NotImplementedError\n\n def get_indexes(self):\n \"\"\"\n Get a list of shuffled indexes for mosaic augmentation.\n\n Returns:\n (List[int]): A list of shuffled indexes from the dataset.\n\n Examples:\n >>> transform = BaseMixTransform(dataset)\n >>> indexes = transform.get_indexes()\n >>> print(indexes) # [3, 18, 7, 2]\n \"\"\"\n return random.randint(0, len(self.dataset) - 1)\n\n @staticmethod\n def _update_label_text(labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Update label text and class IDs for mixed labels in image augmentation.\n\n This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels,\n creating a unified set of text labels and updating class IDs accordingly.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields,\n and optionally a 'mix_labels' field with additional label dictionaries.\n\n Returns:\n (Dict[str, Any]): The updated labels dictionary with unified text labels and updated class IDs.\n\n Examples:\n >>> labels = {\n ... \"texts\": [[\"cat\"], [\"dog\"]],\n ... \"cls\": torch.tensor([[0], [1]]),\n ... \"mix_labels\": [{\"texts\": [[\"bird\"], [\"fish\"]], \"cls\": torch.tensor([[0], [1]])}],\n ... }\n >>> updated_labels = self._update_label_text(labels)\n >>> print(updated_labels[\"texts\"])\n [['cat'], ['dog'], ['bird'], ['fish']]\n >>> print(updated_labels[\"cls\"])\n tensor([[0],\n [1]])\n >>> print(updated_labels[\"mix_labels\"][0][\"cls\"])\n tensor([[2],\n [3]])\n \"\"\"\n if \"texts\" not in labels:\n return labels\n\n mix_texts = sum([labels[\"texts\"]] + [x[\"texts\"] for x in labels[\"mix_labels\"]], [])\n mix_texts = list({tuple(x) for x in mix_texts})\n text2id = {text: i for i, text in enumerate(mix_texts)}\n\n for label in [labels] + labels[\"mix_labels\"]:\n for i, cls in enumerate(label[\"cls\"].squeeze(-1).tolist()):\n text = label[\"texts\"][int(cls)]\n label[\"cls\"][i] = text2id[tuple(text)]\n label[\"texts\"] = mix_texts\n return labels", "chunk_type": "class", "name": "BaseMixTransform", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 318, "end_line": 487, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Base class for mix transformations like Cutmix, MixUp and Mosaic.\n\nThis class provides a foundation for implementing mix transformations on datasets. 
It handles the\nprobability-based application of transforms and manages the mixing of multiple images and labels.\n\nAttributes:\n dataset (Any): The dataset object containing images and labels.\n pre_transform (Callable | None): Optional transform to apply before mixing.\n p (float): Probability of applying the mix transformation.\n\nMethods:\n __call__: Apply the mix transformation to the input labels.\n _mix_transform: Abstract method to be implemented by subclasses for specific mix operations.\n get_indexes: Abstract method to get indexes of images to be mixed.\n _update_label_text: Update label text for mixed images.\n\nExamples:\n >>> class CustomMixTransform(BaseMixTransform):\n ... def _mix_transform(self, labels):\n ... # Implement custom mix logic here\n ... return labels\n ...\n ... def get_indexes(self):\n ... return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]\n >>> dataset = YourDataset()\n >>> transform = CustomMixTransform(dataset, p=0.5)\n >>> mixed_labels = transform(original_labels)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_BaseMixTransform_0bda0703" }, { "content": "class Mosaic(BaseMixTransform):\n \"\"\"\n Mosaic augmentation for image datasets.\n\n This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.\n The augmentation is applied to a dataset with a given probability.\n\n Attributes:\n dataset: The dataset on which the mosaic augmentation is applied.\n imgsz (int): Image size (height and width) after mosaic pipeline of a single image.\n p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1.\n n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).\n border (Tuple[int, int]): Border size for width and height.\n\n Methods:\n get_indexes: Return a list of random indexes from the dataset.\n _mix_transform: Apply mixup transformation to the input image and labels.\n _mosaic3: Create a 1x3 image mosaic.\n _mosaic4: Create a 2x2 image mosaic.\n _mosaic9: Create a 3x3 image mosaic.\n _update_labels: Update labels with padding.\n _cat_labels: Concatenate labels and clips mosaic border instances.\n\n Examples:\n >>> from ultralytics.data.augment import Mosaic\n >>> dataset = YourDataset(...) 
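A toy sketch of the flow that BaseMixTransform.__call__ drives (index selection, mixing, cleanup), using a hypothetical in-memory dataset and a simple averaging mix rather than a real augmentation:

    import numpy as np
    from ultralytics.data.augment import BaseMixTransform

    class DummyDataset:
        # Minimal stand-in: __call__ only needs __len__ and get_image_and_label().
        def __len__(self):
            return 8

        def get_image_and_label(self, i):
            return {"img": np.full((4, 4, 3), i, dtype=np.uint8)}

    class AverageMix(BaseMixTransform):
        def _mix_transform(self, labels):
            other = labels["mix_labels"][0]  # populated by __call__ before this hook runs
            labels["img"] = ((labels["img"].astype(np.int32) + other["img"]) // 2).astype(np.uint8)
            return labels

    transform = AverageMix(DummyDataset(), p=1.0)  # always applied
    mixed = transform({"img": np.zeros((4, 4, 3), dtype=np.uint8)})
    print(mixed["img"].dtype, "mix_labels" in mixed)  # uint8 False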
# Your image dataset\n >>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4)\n >>> augmented_labels = mosaic_aug(original_labels)\n \"\"\"\n\n def __init__(self, dataset, imgsz: int = 640, p: float = 1.0, n: int = 4):\n \"\"\"\n Initialize the Mosaic augmentation object.\n\n This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.\n The augmentation is applied to a dataset with a given probability.\n\n Args:\n dataset (Any): The dataset on which the mosaic augmentation is applied.\n imgsz (int): Image size (height and width) after mosaic pipeline of a single image.\n p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1.\n n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).\n\n Examples:\n >>> from ultralytics.data.augment import Mosaic\n >>> dataset = YourDataset(...)\n >>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4)\n \"\"\"\n assert 0 <= p <= 1.0, f\"The probability should be in range [0, 1], but got {p}.\"\n assert n in {4, 9}, \"grid must be equal to 4 or 9.\"\n super().__init__(dataset=dataset, p=p)\n self.imgsz = imgsz\n self.border = (-imgsz // 2, -imgsz // 2) # width, height\n self.n = n\n self.buffer_enabled = self.dataset.cache != \"ram\"\n\n def get_indexes(self):\n \"\"\"\n Return a list of random indexes from the dataset for mosaic augmentation.\n\n This method selects random image indexes either from a buffer or from the entire dataset, depending on\n the 'buffer' parameter. It is used to choose images for creating mosaic augmentations.\n\n Returns:\n (List[int]): A list of random image indexes. The length of the list is n-1, where n is the number\n of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).\n\n Examples:\n >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)\n >>> indexes = mosaic.get_indexes()\n >>> print(len(indexes)) # Output: 3\n \"\"\"\n if self.buffer_enabled: # select images from buffer\n return random.choices(list(self.dataset.buffer), k=self.n - 1)\n else: # select any images\n return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]\n\n def _mix_transform(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply mosaic augmentation to the input image and labels.\n\n This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute.\n It ensures that rectangular annotations are not present and that there are other images available for\n mosaic augmentation.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image data and annotations. 
Expected keys include:\n - 'rect_shape': Should be None as rect and mosaic are mutually exclusive.\n - 'mix_labels': A list of dictionaries containing data for other images to be used in the mosaic.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the mosaic-augmented image and updated annotations.\n\n Raises:\n AssertionError: If 'rect_shape' is not None or if 'mix_labels' is empty.\n\n Examples:\n >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)\n >>> augmented_data = mosaic._mix_transform(labels)\n \"\"\"\n assert labels.get(\"rect_shape\", None) is None, \"rect and mosaic are mutually exclusive.\"\n assert len(labels.get(\"mix_labels\", [])), \"There are no other images for mosaic augment.\"\n return (\n self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)\n ) # This code is modified for mosaic3 method.\n\n def _mosaic3(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Create a 1x3 image mosaic by combining three images.\n\n This method arranges three images in a horizontal layout, with the main image in the center and two\n additional images on either side. It's part of the Mosaic augmentation technique used in object detection.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image and label information for the main (center) image.\n Must include 'img' key with the image array, and 'mix_labels' key with a list of two\n dictionaries containing information for the side images.\n\n Returns:\n (Dict[str, Any]): A dictionary with the mosaic image and updated labels. Keys include:\n - 'img' (np.ndarray): The mosaic image array with shape (H, W, C).\n - Other keys from the input labels, updated to reflect the new image dimensions.\n\n Examples:\n >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=3)\n >>> labels = {\n ... \"img\": np.random.rand(480, 640, 3),\n ... \"mix_labels\": [{\"img\": np.random.rand(480, 640, 3)} for _ in range(2)],\n ... }\n >>> result = mosaic._mosaic3(labels)\n >>> print(result[\"img\"].shape)\n (640, 640, 3)\n \"\"\"\n mosaic_labels = []\n s = self.imgsz\n for i in range(3):\n labels_patch = labels if i == 0 else labels[\"mix_labels\"][i - 1]\n # Load image\n img = labels_patch[\"img\"]\n h, w = labels_patch.pop(\"resized_shape\")\n\n # Place img in img3\n if i == 0: # center\n img3 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 3 tiles\n h0, w0 = h, w\n c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates\n elif i == 1: # right\n c = s + w0, s, s + w0 + w, s + h\n elif i == 2: # left\n c = s - w, s + h0 - h, s, s + h0\n\n padw, padh = c[:2]\n x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coordinates\n\n img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax]\n # hp, wp = h, w # height, width previous for next iteration\n\n # Labels assuming imgsz*2 mosaic size\n labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])\n mosaic_labels.append(labels_patch)\n final_labels = self._cat_labels(mosaic_labels)\n\n final_labels[\"img\"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]\n return final_labels\n\n def _mosaic4(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Create a 2x2 image mosaic from four input images.\n\n This method combines four images into a single mosaic image by placing them in a 2x2 grid. 
It also\n updates the corresponding labels for each image in the mosaic.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image data and labels for the base image (index 0) and three\n additional images (indices 1-3) in the 'mix_labels' key.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the mosaic image and updated labels. The 'img' key contains the mosaic\n image as a numpy array, and other keys contain the combined and adjusted labels for all four images.\n\n Examples:\n >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)\n >>> labels = {\n ... \"img\": np.random.rand(480, 640, 3),\n ... \"mix_labels\": [{\"img\": np.random.rand(480, 640, 3)} for _ in range(3)],\n ... }\n >>> result = mosaic._mosaic4(labels)\n >>> assert result[\"img\"].shape == (1280, 1280, 3)\n \"\"\"\n mosaic_labels = []\n s = self.imgsz\n yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y\n for i in range(4):\n labels_patch = labels if i == 0 else labels[\"mix_labels\"][i - 1]\n # Load image\n img = labels_patch[\"img\"]\n h, w = labels_patch.pop(\"resized_shape\")\n\n # Place img in img4\n if i == 0: # top left\n img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles\n x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)\n x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)\n elif i == 1: # top right\n x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc\n x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h\n elif i == 2: # bottom left\n x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)\n x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)\n elif i == 3: # bottom right\n x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)\n x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)\n\n img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]\n padw = x1a - x1b\n padh = y1a - y1b\n\n labels_patch = self._update_labels(labels_patch, padw, padh)\n mosaic_labels.append(labels_patch)\n final_labels = self._cat_labels(mosaic_labels)\n final_labels[\"img\"] = img4\n return final_labels\n\n def _mosaic9(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Create a 3x3 image mosaic from the input image and eight additional images.\n\n This method combines nine images into a single mosaic image. The input image is placed at the center,\n and eight additional images from the dataset are placed around it in a 3x3 grid pattern.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing the input image and its associated labels. It should have\n the following keys:\n - 'img' (np.ndarray): The input image.\n - 'resized_shape' (Tuple[int, int]): The shape of the resized image (height, width).\n - 'mix_labels' (List[Dict]): A list of dictionaries containing information for the additional\n eight images, each with the same structure as the input labels.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the mosaic image and updated labels. 
It includes the following keys:\n - 'img' (np.ndarray): The final mosaic image.\n - Other keys from the input labels, updated to reflect the new mosaic arrangement.\n\n Examples:\n >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=9)\n >>> input_labels = dataset[0]\n >>> mosaic_result = mosaic._mosaic9(input_labels)\n >>> mosaic_image = mosaic_result[\"img\"]\n \"\"\"\n mosaic_labels = []\n s = self.imgsz\n hp, wp = -1, -1 # height, width previous\n for i in range(9):\n labels_patch = labels if i == 0 else labels[\"mix_labels\"][i - 1]\n # Load image\n img = labels_patch[\"img\"]\n h, w = labels_patch.pop(\"resized_shape\")\n\n # Place img in img9\n if i == 0: # center\n img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles\n h0, w0 = h, w\n c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates\n elif i == 1: # top\n c = s, s - h, s + w, s\n elif i == 2: # top right\n c = s + wp, s - h, s + wp + w, s\n elif i == 3: # right\n c = s + w0, s, s + w0 + w, s + h\n elif i == 4: # bottom right\n c = s + w0, s + hp, s + w0 + w, s + hp + h\n elif i == 5: # bottom\n c = s + w0 - w, s + h0, s + w0, s + h0 + h\n elif i == 6: # bottom left\n c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h\n elif i == 7: # left\n c = s - w, s + h0 - h, s, s + h0\n elif i == 8: # top left\n c = s - w, s + h0 - hp - h, s, s + h0 - hp\n\n padw, padh = c[:2]\n x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coordinates\n\n # Image\n img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax]\n hp, wp = h, w # height, width previous for next iteration\n\n # Labels assuming imgsz*2 mosaic size\n labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])\n mosaic_labels.append(labels_patch)\n final_labels = self._cat_labels(mosaic_labels)\n\n final_labels[\"img\"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]\n return final_labels\n\n @staticmethod\n def _update_labels(labels, padw: int, padh: int) -> Dict[str, Any]:\n \"\"\"\n Update label coordinates with padding values.\n\n This method adjusts the bounding box coordinates of object instances in the labels by adding padding\n values. 
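As a quick check of the border convention used in the mosaic methods above (a sketch based on the code shown, with imgsz=640): the border is (-320, -320), so the final slice keeps the central 2*imgsz square of the 3*imgsz canvas, and the per-tile label offsets are shifted by padw + border[0].

>>> import numpy as np
>>> s = 640
>>> border = (-s // 2, -s // 2)                                        # as set in Mosaic.__init__
>>> canvas = np.full((s * 3, s * 3, 3), 114, dtype=np.uint8)           # 3x3 mosaic canvas
>>> cropped = canvas[-border[0] : border[0], -border[1] : border[1]]   # i.e. canvas[320:-320, 320:-320]
>>> cropped.shape
(1280, 1280, 3)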
It also denormalizes the coordinates if they were previously normalized.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image and instance information.\n padw (int): Padding width to be added to the x-coordinates.\n padh (int): Padding height to be added to the y-coordinates.\n\n Returns:\n (dict): Updated labels dictionary with adjusted instance coordinates.\n\n Examples:\n >>> labels = {\"img\": np.zeros((100, 100, 3)), \"instances\": Instances(...)}\n >>> padw, padh = 50, 50\n >>> updated_labels = Mosaic._update_labels(labels, padw, padh)\n \"\"\"\n nh, nw = labels[\"img\"].shape[:2]\n labels[\"instances\"].convert_bbox(format=\"xyxy\")\n labels[\"instances\"].denormalize(nw, nh)\n labels[\"instances\"].add_padding(padw, padh)\n return labels\n\n def _cat_labels(self, mosaic_labels: List[Dict[str, Any]]) -> Dict[str, Any]:\n \"\"\"\n Concatenate and process labels for mosaic augmentation.\n\n This method combines labels from multiple images used in mosaic augmentation, clips instances to the\n mosaic border, and removes zero-area boxes.\n\n Args:\n mosaic_labels (List[Dict[str, Any]]): A list of label dictionaries for each image in the mosaic.\n\n Returns:\n (Dict[str, Any]): A dictionary containing concatenated and processed labels for the mosaic image, including:\n - im_file (str): File path of the first image in the mosaic.\n - ori_shape (Tuple[int, int]): Original shape of the first image.\n - resized_shape (Tuple[int, int]): Shape of the mosaic image (imgsz * 2, imgsz * 2).\n - cls (np.ndarray): Concatenated class labels.\n - instances (Instances): Concatenated instance annotations.\n - mosaic_border (Tuple[int, int]): Mosaic border size.\n - texts (List[str], optional): Text labels if present in the original labels.\n\n Examples:\n >>> mosaic = Mosaic(dataset, imgsz=640)\n >>> mosaic_labels = [{\"cls\": np.array([0, 1]), \"instances\": Instances(...)} for _ in range(4)]\n >>> result = mosaic._cat_labels(mosaic_labels)\n >>> print(result.keys())\n dict_keys(['im_file', 'ori_shape', 'resized_shape', 'cls', 'instances', 'mosaic_border'])\n \"\"\"\n if len(mosaic_labels) == 0:\n return {}\n cls = []\n instances = []\n imgsz = self.imgsz * 2 # mosaic imgsz\n for labels in mosaic_labels:\n cls.append(labels[\"cls\"])\n instances.append(labels[\"instances\"])\n # Final labels\n final_labels = {\n \"im_file\": mosaic_labels[0][\"im_file\"],\n \"ori_shape\": mosaic_labels[0][\"ori_shape\"],\n \"resized_shape\": (imgsz, imgsz),\n \"cls\": np.concatenate(cls, 0),\n \"instances\": Instances.concatenate(instances, axis=0),\n \"mosaic_border\": self.border,\n }\n final_labels[\"instances\"].clip(imgsz, imgsz)\n good = final_labels[\"instances\"].remove_zero_area_boxes()\n final_labels[\"cls\"] = final_labels[\"cls\"][good]\n if \"texts\" in mosaic_labels[0]:\n final_labels[\"texts\"] = mosaic_labels[0][\"texts\"]\n return final_labels", "chunk_type": "class", "name": "Mosaic", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 490, "end_line": 861, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Mosaic augmentation for image datasets.\n\nThis class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.\nThe augmentation is applied to a dataset with a given probability.\n\nAttributes:\n dataset: The dataset on which the mosaic augmentation is applied.\n imgsz (int): Image size (height and width) after mosaic pipeline of a single image.\n p (float): Probability of applying the mosaic augmentation. 
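To make the 2x2 placement arithmetic in _mosaic4 concrete, here is a small worked sketch (the mosaic centre and image sizes are illustrative, not taken from the source): the paste region of the top-left tile is clamped at the mosaic centre (xc, yc), the matching crop is taken from the bottom-right of the source image, and padw/padh recover the shift later applied to that tile's labels.

>>> s = 640                                   # imgsz
>>> xc, yc = 700, 500                         # an example mosaic centre
>>> w, h = 640, 480                           # resized (width, height) of the top-left image
>>> x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc        # region on the canvas
>>> x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h        # crop from the source image
>>> (x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b)
((60, 20, 700, 500), (0, 0, 640, 480))
>>> x1a - x1b, y1a - y1b                      # padw, padh passed to _update_labels
(60, 20)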
Must be in the range 0-1.\n n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).\n border (Tuple[int, int]): Border size for width and height.\n\nMethods:\n get_indexes: Return a list of random indexes from the dataset.\n _mix_transform: Apply mixup transformation to the input image and labels.\n _mosaic3: Create a 1x3 image mosaic.\n _mosaic4: Create a 2x2 image mosaic.\n _mosaic9: Create a 3x3 image mosaic.\n _update_labels: Update labels with padding.\n _cat_labels: Concatenate labels and clips mosaic border instances.\n\nExamples:\n >>> from ultralytics.data.augment import Mosaic\n >>> dataset = YourDataset(...) # Your image dataset\n >>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4)\n >>> augmented_labels = mosaic_aug(original_labels)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations", "BaseMixTransform" ], "chunk_id": "class_Mosaic_6a082acc" }, { "content": "class MixUp(BaseMixTransform):\n \"\"\"\n Apply MixUp augmentation to image datasets.\n\n This class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk\n Minimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.\n\n Attributes:\n dataset (Any): The dataset to which MixUp augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before MixUp.\n p (float): Probability of applying MixUp augmentation.\n\n Methods:\n _mix_transform: Apply MixUp augmentation to the input labels.\n\n Examples:\n >>> from ultralytics.data.augment import MixUp\n >>> dataset = YourDataset(...) # Your image dataset\n >>> mixup = MixUp(dataset, p=0.5)\n >>> augmented_labels = mixup(original_labels)\n \"\"\"\n\n def __init__(self, dataset, pre_transform=None, p: float = 0.0) -> None:\n \"\"\"\n Initialize the MixUp augmentation object.\n\n MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel\n values and labels. This implementation is designed for use with the Ultralytics YOLO framework.\n\n Args:\n dataset (Any): The dataset to which MixUp augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply to images before MixUp.\n p (float): Probability of applying MixUp augmentation to an image. 
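As a rough sketch of the MixUp blend described above (image shapes are illustrative; the real transform also concatenates the two instance and class label sets): a ratio drawn from Beta(32, 32) is concentrated around 0.5, so both images contribute roughly equally.

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> img1 = rng.integers(0, 255, (640, 640, 3), dtype=np.uint8)
>>> img2 = rng.integers(0, 255, (640, 640, 3), dtype=np.uint8)
>>> r = rng.beta(32.0, 32.0)                  # mixing ratio, clustered near 0.5
>>> mixed = (img1 * r + img2 * (1 - r)).astype(np.uint8)
>>> mixed.shape, mixed.dtype
((640, 640, 3), dtype('uint8'))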
Must be in the range [0, 1].\n\n Examples:\n >>> from ultralytics.data.dataset import YOLODataset\n >>> dataset = YOLODataset(\"path/to/data.yaml\")\n >>> mixup = MixUp(dataset, pre_transform=None, p=0.5)\n \"\"\"\n super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)\n\n def _mix_transform(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply MixUp augmentation to the input labels.\n\n This method implements the MixUp augmentation technique as described in the paper\n \"mixup: Beyond Empirical Risk Minimization\" (https://arxiv.org/abs/1710.09412).\n\n Args:\n labels (Dict[str, Any]): A dictionary containing the original image and label information.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the mixed-up image and combined label information.\n\n Examples:\n >>> mixer = MixUp(dataset)\n >>> mixed_labels = mixer._mix_transform(labels)\n \"\"\"\n r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0\n labels2 = labels[\"mix_labels\"][0]\n labels[\"img\"] = (labels[\"img\"] * r + labels2[\"img\"] * (1 - r)).astype(np.uint8)\n labels[\"instances\"] = Instances.concatenate([labels[\"instances\"], labels2[\"instances\"]], axis=0)\n labels[\"cls\"] = np.concatenate([labels[\"cls\"], labels2[\"cls\"]], 0)\n return labels", "chunk_type": "class", "name": "MixUp", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 864, "end_line": 927, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Apply MixUp augmentation to image datasets.\n\nThis class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk\nMinimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.\n\nAttributes:\n dataset (Any): The dataset to which MixUp augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before MixUp.\n p (float): Probability of applying MixUp augmentation.\n\nMethods:\n _mix_transform: Apply MixUp augmentation to the input labels.\n\nExamples:\n >>> from ultralytics.data.augment import MixUp\n >>> dataset = YourDataset(...) 
# Your image dataset\n >>> mixup = MixUp(dataset, p=0.5)\n >>> augmented_labels = mixup(original_labels)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations", "BaseMixTransform" ], "chunk_id": "class_MixUp_50a7d3c6" }, { "content": "class CutMix(BaseMixTransform):\n \"\"\"\n Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.\n\n CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from another image,\n and adjusts the labels proportionally to the area of the mixed region.\n\n Attributes:\n dataset (Any): The dataset to which CutMix augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before CutMix.\n p (float): Probability of applying CutMix augmentation.\n beta (float): Beta distribution parameter for sampling the mixing ratio.\n num_areas (int): Number of areas to try to cut and mix.\n\n Methods:\n _mix_transform: Apply CutMix augmentation to the input labels.\n _rand_bbox: Generate random bounding box coordinates for the cut region.\n\n Examples:\n >>> from ultralytics.data.augment import CutMix\n >>> dataset = YourDataset(...) 
# Your image dataset\n >>> cutmix = CutMix(dataset, p=0.5)\n >>> augmented_labels = cutmix(original_labels)\n \"\"\"\n\n def __init__(self, dataset, pre_transform=None, p: float = 0.0, beta: float = 1.0, num_areas: int = 3) -> None:\n \"\"\"\n Initialize the CutMix augmentation object.\n\n Args:\n dataset (Any): The dataset to which CutMix augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before CutMix.\n p (float): Probability of applying CutMix augmentation.\n beta (float): Beta distribution parameter for sampling the mixing ratio.\n num_areas (int): Number of areas to try to cut and mix.\n \"\"\"\n super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)\n self.beta = beta\n self.num_areas = num_areas\n\n def _rand_bbox(self, width: int, height: int) -> Tuple[int, int, int, int]:\n \"\"\"\n Generate random bounding box coordinates for the cut region.\n\n Args:\n width (int): Width of the image.\n height (int): Height of the image.\n\n Returns:\n (Tuple[int]): (x1, y1, x2, y2) coordinates of the bounding box.\n \"\"\"\n # Sample mixing ratio from Beta distribution\n lam = np.random.beta(self.beta, self.beta)\n\n cut_ratio = np.sqrt(1.0 - lam)\n cut_w = int(width * cut_ratio)\n cut_h = int(height * cut_ratio)\n\n # Random center\n cx = np.random.randint(width)\n cy = np.random.randint(height)\n\n # Bounding box coordinates\n x1 = np.clip(cx - cut_w // 2, 0, width)\n y1 = np.clip(cy - cut_h // 2, 0, height)\n x2 = np.clip(cx + cut_w // 2, 0, width)\n y2 = np.clip(cy + cut_h // 2, 0, height)\n\n return x1, y1, x2, y2\n\n def _mix_transform(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply CutMix augmentation to the input labels.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing the original image and label information.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the mixed image and adjusted labels.\n\n Examples:\n >>> cutter = CutMix(dataset)\n >>> mixed_labels = cutter._mix_transform(labels)\n \"\"\"\n # Get a random second image\n h, w = labels[\"img\"].shape[:2]\n\n cut_areas = np.asarray([self._rand_bbox(w, h) for _ in range(self.num_areas)], dtype=np.float32)\n ioa1 = bbox_ioa(cut_areas, labels[\"instances\"].bboxes) # (self.num_areas, num_boxes)\n idx = np.nonzero(ioa1.sum(axis=1) <= 0)[0]\n if len(idx) == 0:\n return labels\n\n labels2 = labels.pop(\"mix_labels\")[0]\n area = cut_areas[np.random.choice(idx)] # randomly select one\n ioa2 = bbox_ioa(area[None], labels2[\"instances\"].bboxes).squeeze(0)\n indexes2 = np.nonzero(ioa2 >= (0.01 if len(labels[\"instances\"].segments) else 0.1))[0]\n if len(indexes2) == 0:\n return labels\n\n instances2 = labels2[\"instances\"][indexes2]\n instances2.convert_bbox(\"xyxy\")\n instances2.denormalize(w, h)\n\n # Apply CutMix\n x1, y1, x2, y2 = area.astype(np.int32)\n labels[\"img\"][y1:y2, x1:x2] = labels2[\"img\"][y1:y2, x1:x2]\n\n # Restrain instances2 to the random bounding border\n instances2.add_padding(-x1, -y1)\n instances2.clip(x2 - x1, y2 - y1)\n instances2.add_padding(x1, y1)\n\n labels[\"cls\"] = np.concatenate([labels[\"cls\"], labels2[\"cls\"][indexes2]], axis=0)\n labels[\"instances\"] = Instances.concatenate([labels[\"instances\"], instances2], axis=0)\n return labels", "chunk_type": "class", "name": "CutMix", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 930, "end_line": 1045, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Apply CutMix augmentation to image datasets as described in the paper 
https://arxiv.org/abs/1905.04899.\n\nCutMix combines two images by replacing a random rectangular region of one image with the corresponding region from another image,\nand adjusts the labels proportionally to the area of the mixed region.\n\nAttributes:\n dataset (Any): The dataset to which CutMix augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before CutMix.\n p (float): Probability of applying CutMix augmentation.\n beta (float): Beta distribution parameter for sampling the mixing ratio.\n num_areas (int): Number of areas to try to cut and mix.\n\nMethods:\n _mix_transform: Apply CutMix augmentation to the input labels.\n _rand_bbox: Generate random bounding box coordinates for the cut region.\n\nExamples:\n >>> from ultralytics.data.augment import CutMix\n >>> dataset = YourDataset(...) # Your image dataset\n >>> cutmix = CutMix(dataset, p=0.5)\n >>> augmented_labels = cutmix(original_labels)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations", "BaseMixTransform" ], "chunk_id": "class_CutMix_0e01e339" }, { "content": "class RandomPerspective:\n \"\"\"\n Implement random perspective and affine transformations on images and corresponding annotations.\n\n This class applies random rotations, translations, scaling, shearing, and perspective transformations\n to images and their associated bounding boxes, segments, and keypoints. 
It can be used as part of an\n augmentation pipeline for object detection and instance segmentation tasks.\n\n Attributes:\n degrees (float): Maximum absolute degree range for random rotations.\n translate (float): Maximum translation as a fraction of the image size.\n scale (float): Scaling factor range, e.g., scale=0.1 means 0.9-1.1.\n shear (float): Maximum shear angle in degrees.\n perspective (float): Perspective distortion factor.\n border (Tuple[int, int]): Mosaic border size as (x, y).\n pre_transform (Callable | None): Optional transform to apply before the random perspective.\n\n Methods:\n affine_transform: Apply affine transformations to the input image.\n apply_bboxes: Transform bounding boxes using the affine matrix.\n apply_segments: Transform segments and generate new bounding boxes.\n apply_keypoints: Transform keypoints using the affine matrix.\n __call__: Apply the random perspective transformation to images and annotations.\n box_candidates: Filter transformed bounding boxes based on size and aspect ratio.\n\n Examples:\n >>> transform = RandomPerspective(degrees=10, translate=0.1, scale=0.1, shear=10)\n >>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)\n >>> labels = {\"img\": image, \"cls\": np.array([0, 1]), \"instances\": Instances(...)}\n >>> result = transform(labels)\n >>> transformed_image = result[\"img\"]\n >>> transformed_instances = result[\"instances\"]\n \"\"\"\n\n def __init__(\n self,\n degrees: float = 0.0,\n translate: float = 0.1,\n scale: float = 0.5,\n shear: float = 0.0,\n perspective: float = 0.0,\n border: Tuple[int, int] = (0, 0),\n pre_transform=None,\n ):\n \"\"\"\n Initialize RandomPerspective object with transformation parameters.\n\n This class implements random perspective and affine transformations on images and corresponding bounding boxes,\n segments, and keypoints. Transformations include rotation, translation, scaling, and shearing.\n\n Args:\n degrees (float): Degree range for random rotations.\n translate (float): Fraction of total width and height for random translation.\n scale (float): Scaling factor interval, e.g., a scale factor of 0.5 allows a resize between 50%-150%.\n shear (float): Shear intensity (angle in degrees).\n perspective (float): Perspective distortion factor.\n border (Tuple[int, int]): Tuple specifying mosaic border (top/bottom, left/right).\n pre_transform (Callable | None): Function/transform to apply to the image before starting the random\n transformation.\n\n Examples:\n >>> transform = RandomPerspective(degrees=10.0, translate=0.1, scale=0.5, shear=5.0)\n >>> result = transform(labels) # Apply random perspective to labels\n \"\"\"\n self.degrees = degrees\n self.translate = translate\n self.scale = scale\n self.shear = shear\n self.perspective = perspective\n self.border = border # mosaic border\n self.pre_transform = pre_transform\n\n def affine_transform(self, img: np.ndarray, border: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray, float]:\n \"\"\"\n Apply a sequence of affine transformations centered around the image center.\n\n This function performs a series of geometric transformations on the input image, including\n translation, perspective change, rotation, scaling, and shearing. 
The transformations are\n applied in a specific order to maintain consistency.\n\n Args:\n img (np.ndarray): Input image to be transformed.\n border (Tuple[int, int]): Border dimensions for the transformed image.\n\n Returns:\n img (np.ndarray): Transformed image.\n M (np.ndarray): 3x3 transformation matrix.\n s (float): Scale factor applied during the transformation.\n\n Examples:\n >>> import numpy as np\n >>> img = np.random.rand(100, 100, 3)\n >>> border = (10, 10)\n >>> transformed_img, matrix, scale = affine_transform(img, border)\n \"\"\"\n # Center\n C = np.eye(3, dtype=np.float32)\n\n C[0, 2] = -img.shape[1] / 2 # x translation (pixels)\n C[1, 2] = -img.shape[0] / 2 # y translation (pixels)\n\n # Perspective\n P = np.eye(3, dtype=np.float32)\n P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)\n P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)\n\n # Rotation and Scale\n R = np.eye(3, dtype=np.float32)\n a = random.uniform(-self.degrees, self.degrees)\n # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations\n s = random.uniform(1 - self.scale, 1 + self.scale)\n # s = 2 ** random.uniform(-scale, scale)\n R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)\n\n # Shear\n S = np.eye(3, dtype=np.float32)\n S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)\n S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)\n\n # Translation\n T = np.eye(3, dtype=np.float32)\n T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)\n T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)\n\n # Combined rotation matrix\n M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT\n # Affine image\n if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed\n if self.perspective:\n img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))\n else: # affine\n img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))\n if img.ndim == 2:\n img = img[..., None]\n return img, M, s\n\n def apply_bboxes(self, bboxes: np.ndarray, M: np.ndarray) -> np.ndarray:\n \"\"\"\n Apply affine transformation to bounding boxes.\n\n This function applies an affine transformation to a set of bounding boxes using the provided\n transformation matrix.\n\n Args:\n bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number\n of bounding boxes.\n M (np.ndarray): Affine transformation matrix with shape (3, 3).\n\n Returns:\n (np.ndarray): Transformed bounding boxes in xyxy format with shape (N, 4).\n\n Examples:\n >>> bboxes = torch.tensor([[10, 10, 20, 20], [30, 30, 40, 40]])\n >>> M = torch.eye(3)\n >>> transformed_bboxes = apply_bboxes(bboxes, M)\n \"\"\"\n n = len(bboxes)\n if n == 0:\n return bboxes\n\n xy = np.ones((n * 4, 3), dtype=bboxes.dtype)\n xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1\n xy = xy @ M.T # transform\n xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine\n\n # Create new boxes\n x = xy[:, [0, 2, 4, 6]]\n y = xy[:, [1, 3, 5, 7]]\n return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T\n\n def apply_segments(self, segments: np.ndarray, M: np.ndarray) -> 
Tuple[np.ndarray, np.ndarray]:\n \"\"\"\n Apply affine transformations to segments and generate new bounding boxes.\n\n This function applies affine transformations to input segments and generates new bounding boxes based on\n the transformed segments. It clips the transformed segments to fit within the new bounding boxes.\n\n Args:\n segments (np.ndarray): Input segments with shape (N, M, 2), where N is the number of segments and M is the\n number of points in each segment.\n M (np.ndarray): Affine transformation matrix with shape (3, 3).\n\n Returns:\n bboxes (np.ndarray): New bounding boxes with shape (N, 4) in xyxy format.\n segments (np.ndarray): Transformed and clipped segments with shape (N, M, 2).\n\n Examples:\n >>> segments = np.random.rand(10, 500, 2) # 10 segments with 500 points each\n >>> M = np.eye(3) # Identity transformation matrix\n >>> new_bboxes, new_segments = apply_segments(segments, M)\n \"\"\"\n n, num = segments.shape[:2]\n if n == 0:\n return [], segments\n\n xy = np.ones((n * num, 3), dtype=segments.dtype)\n segments = segments.reshape(-1, 2)\n xy[:, :2] = segments\n xy = xy @ M.T # transform\n xy = xy[:, :2] / xy[:, 2:3]\n segments = xy.reshape(n, -1, 2)\n bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)\n segments[..., 0] = segments[..., 0].clip(bboxes[:, 0:1], bboxes[:, 2:3])\n segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4])\n return bboxes, segments\n\n def apply_keypoints(self, keypoints: np.ndarray, M: np.ndarray) -> np.ndarray:\n \"\"\"\n Apply affine transformation to keypoints.\n\n This method transforms the input keypoints using the provided affine transformation matrix. It handles\n perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image\n boundaries after transformation.\n\n Args:\n keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances,\n 17 is the number of keypoints per instance, and 3 represents (x, y, visibility).\n M (np.ndarray): 3x3 affine transformation matrix.\n\n Returns:\n (np.ndarray): Transformed keypoints array with the same shape as input (N, 17, 3).\n\n Examples:\n >>> random_perspective = RandomPerspective()\n >>> keypoints = np.random.rand(5, 17, 3) # 5 instances, 17 keypoints each\n >>> M = np.eye(3) # Identity transformation\n >>> transformed_keypoints = random_perspective.apply_keypoints(keypoints, M)\n \"\"\"\n n, nkpt = keypoints.shape[:2]\n if n == 0:\n return keypoints\n xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)\n visible = keypoints[..., 2].reshape(n * nkpt, 1)\n xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)\n xy = xy @ M.T # transform\n xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine\n out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])\n visible[out_mask] = 0\n return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply random perspective and affine transformations to an image and its associated labels.\n\n This method performs a series of transformations including rotation, translation, scaling, shearing,\n and perspective distortion on the input image and adjusts the corresponding bounding boxes, segments,\n and keypoints accordingly.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image data and annotations.\n Must include:\n 'img' (np.ndarray): The input image.\n 'cls' (np.ndarray): Class 
labels.\n 'instances' (Instances): Object instances with bounding boxes, segments, and keypoints.\n May include:\n 'mosaic_border' (Tuple[int, int]): Border size for mosaic augmentation.\n\n Returns:\n (Dict[str, Any]): Transformed labels dictionary containing:\n - 'img' (np.ndarray): The transformed image.\n - 'cls' (np.ndarray): Updated class labels.\n - 'instances' (Instances): Updated object instances.\n - 'resized_shape' (Tuple[int, int]): New image shape after transformation.\n\n Examples:\n >>> transform = RandomPerspective()\n >>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)\n >>> labels = {\n ... \"img\": image,\n ... \"cls\": np.array([0, 1, 2]),\n ... \"instances\": Instances(bboxes=np.array([[10, 10, 50, 50], [100, 100, 150, 150]])),\n ... }\n >>> result = transform(labels)\n >>> assert result[\"img\"].shape[:2] == result[\"resized_shape\"]\n \"\"\"\n if self.pre_transform and \"mosaic_border\" not in labels:\n labels = self.pre_transform(labels)\n labels.pop(\"ratio_pad\", None) # do not need ratio pad\n\n img = labels[\"img\"]\n cls = labels[\"cls\"]\n instances = labels.pop(\"instances\")\n # Make sure the coord formats are right\n instances.convert_bbox(format=\"xyxy\")\n instances.denormalize(*img.shape[:2][::-1])\n\n border = labels.pop(\"mosaic_border\", self.border)\n self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h\n # M is affine matrix\n # Scale for func:`box_candidates`\n img, M, scale = self.affine_transform(img, border)\n\n bboxes = self.apply_bboxes(instances.bboxes, M)\n\n segments = instances.segments\n keypoints = instances.keypoints\n # Update bboxes if there are segments.\n if len(segments):\n bboxes, segments = self.apply_segments(segments, M)\n\n if keypoints is not None:\n keypoints = self.apply_keypoints(keypoints, M)\n new_instances = Instances(bboxes, segments, keypoints, bbox_format=\"xyxy\", normalized=False)\n # Clip\n new_instances.clip(*self.size)\n\n # Filter instances\n instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)\n # Make the bboxes have the same scale with new_bboxes\n i = self.box_candidates(\n box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10\n )\n labels[\"instances\"] = new_instances[i]\n labels[\"cls\"] = cls[i]\n labels[\"img\"] = img\n labels[\"resized_shape\"] = img.shape[:2]\n return labels\n\n @staticmethod\n def box_candidates(\n box1: np.ndarray,\n box2: np.ndarray,\n wh_thr: int = 2,\n ar_thr: int = 100,\n area_thr: float = 0.1,\n eps: float = 1e-16,\n ) -> np.ndarray:\n \"\"\"\n Compute candidate boxes for further processing based on size and aspect ratio criteria.\n\n This method compares boxes before and after augmentation to determine if they meet specified\n thresholds for width, height, aspect ratio, and area. It's used to filter out boxes that have\n been overly distorted or reduced by the augmentation process.\n\n Args:\n box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where n is the\n number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates.\n box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is\n [x1, y1, x2, y2] in absolute coordinates.\n wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either\n dimension are rejected.\n ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this\n value are rejected.\n area_thr (float): Area ratio threshold. 
Boxes with an area ratio (new/old) less than\n this value are rejected.\n eps (float): Small epsilon value to prevent division by zero.\n\n Returns:\n (np.ndarray): Boolean array of shape (n) indicating which boxes are candidates.\n True values correspond to boxes that meet all criteria.\n\n Examples:\n >>> random_perspective = RandomPerspective()\n >>> box1 = np.array([[0, 0, 100, 100], [0, 0, 50, 50]]).T\n >>> box2 = np.array([[10, 10, 90, 90], [5, 5, 45, 45]]).T\n >>> candidates = random_perspective.box_candidates(box1, box2)\n >>> print(candidates)\n [True True]\n \"\"\"\n w1, h1 = box1[2] - box1[0], box1[3] - box1[1]\n w2, h2 = box2[2] - box2[0], box2[3] - box2[1]\n ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio\n return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates", "chunk_type": "class", "name": "RandomPerspective", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 1048, "end_line": 1410, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": "Implement random perspective and affine transformations on images and corresponding annotations.\n\nThis class applies random rotations, translations, scaling, shearing, and perspective transformations\nto images and their associated bounding boxes, segments, and keypoints. It can be used as part of an\naugmentation pipeline for object detection and instance segmentation tasks.\n\nAttributes:\n degrees (float): Maximum absolute degree range for random rotations.\n translate (float): Maximum translation as a fraction of the image size.\n scale (float): Scaling factor range, e.g., scale=0.1 means 0.9-1.1.\n shear (float): Maximum shear angle in degrees.\n perspective (float): Perspective distortion factor.\n border (Tuple[int, int]): Mosaic border size as (x, y).\n pre_transform (Callable | None): Optional transform to apply before the random perspective.\n\nMethods:\n affine_transform: Apply affine transformations to the input image.\n apply_bboxes: Transform bounding boxes using the affine matrix.\n apply_segments: Transform segments and generate new bounding boxes.\n apply_keypoints: Transform keypoints using the affine matrix.\n __call__: Apply the random perspective transformation to images and annotations.\n box_candidates: Filter transformed bounding boxes based on size and aspect ratio.\n\nExamples:\n >>> transform = RandomPerspective(degrees=10, translate=0.1, scale=0.1, shear=10)\n >>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)\n >>> labels = {\"img\": image, \"cls\": np.array([0, 1]), \"instances\": Instances(...)}\n >>> result = transform(labels)\n >>> transformed_image = result[\"img\"]\n >>> transformed_instances = result[\"instances\"]", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", 
"ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_RandomPerspective_b28e4daf" }, { "content": "class RandomHSV:\n \"\"\"\n Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.\n\n This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.\n\n Attributes:\n hgain (float): Maximum variation for hue. Range is typically [0, 1].\n sgain (float): Maximum variation for saturation. Range is typically [0, 1].\n vgain (float): Maximum variation for value. Range is typically [0, 1].\n\n Methods:\n __call__: Apply random HSV augmentation to an image.\n\n Examples:\n >>> import numpy as np\n >>> from ultralytics.data.augment import RandomHSV\n >>> augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)\n >>> image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)\n >>> labels = {\"img\": image}\n >>> augmenter(labels)\n >>> augmented_image = augmented_labels[\"img\"]\n \"\"\"\n\n def __init__(self, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5) -> None:\n \"\"\"\n Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.\n\n This class applies random adjustments to the HSV channels of an image within specified limits.\n\n Args:\n hgain (float): Maximum variation for hue. Should be in the range [0, 1].\n sgain (float): Maximum variation for saturation. Should be in the range [0, 1].\n vgain (float): Maximum variation for value. Should be in the range [0, 1].\n\n Examples:\n >>> hsv_aug = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)\n >>> hsv_aug(image)\n \"\"\"\n self.hgain = hgain\n self.sgain = sgain\n self.vgain = vgain\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply random HSV augmentation to an image within predefined limits.\n\n This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels.\n The adjustments are made within the limits set by hgain, sgain, and vgain during initialization.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image data and metadata. 
Must include an 'img' key with\n the image as a numpy array.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the mixed image and adjusted labels.\n\n Examples:\n >>> hsv_augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)\n >>> labels = {\"img\": np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)}\n >>> labels = hsv_augmenter(labels)\n >>> augmented_img = labels[\"img\"]\n \"\"\"\n img = labels[\"img\"]\n if img.shape[-1] != 3: # only apply to RGB images\n return labels\n if self.hgain or self.sgain or self.vgain:\n dtype = img.dtype # uint8\n\n r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] # random gains\n x = np.arange(0, 256, dtype=r.dtype)\n # lut_hue = ((x * (r[0] + 1)) % 180).astype(dtype) # original hue implementation from ultralytics<=8.3.78\n lut_hue = ((x + r[0] * 180) % 180).astype(dtype)\n lut_sat = np.clip(x * (r[1] + 1), 0, 255).astype(dtype)\n lut_val = np.clip(x * (r[2] + 1), 0, 255).astype(dtype)\n lut_sat[0] = 0 # prevent pure white changing color, introduced in 8.3.79\n\n hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))\n im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))\n cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed\n return labels", "chunk_type": "class", "name": "RandomHSV", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 1413, "end_line": 1493, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.\n\nThis class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.\n\nAttributes:\n hgain (float): Maximum variation for hue. Range is typically [0, 1].\n sgain (float): Maximum variation for saturation. Range is typically [0, 1].\n vgain (float): Maximum variation for value. 
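A condensed sketch of the LUT mechanics in __call__ above (the gains are fixed here instead of drawn randomly, and the image is a synthetic stand-in): each HSV channel is remapped through a 256-entry lookup table, with hue wrapped modulo 180 because OpenCV stores 8-bit hue in [0, 180).

>>> import cv2
>>> import numpy as np
>>> img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)   # BGR image
>>> r = np.array([0.1, 0.2, -0.3])            # fixed (hue, sat, val) gains for illustration
>>> x = np.arange(0, 256, dtype=r.dtype)
>>> lut_hue = ((x + r[0] * 180) % 180).astype(np.uint8)
>>> lut_sat = np.clip(x * (r[1] + 1), 0, 255).astype(np.uint8)
>>> lut_val = np.clip(x * (r[2] + 1), 0, 255).astype(np.uint8)
>>> lut_sat[0] = 0                            # keep pure white from shifting colour
>>> hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
>>> out = cv2.cvtColor(cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))), cv2.COLOR_HSV2BGR)
>>> out.shape == img.shape
True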
Range is typically [0, 1].\n\nMethods:\n __call__: Apply random HSV augmentation to an image.\n\nExamples:\n >>> import numpy as np\n >>> from ultralytics.data.augment import RandomHSV\n >>> augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)\n >>> image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)\n >>> labels = {\"img\": image}\n >>> augmenter(labels)\n >>> augmented_image = augmented_labels[\"img\"]", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_RandomHSV_9752cddd" }, { "content": "class RandomFlip:\n \"\"\"\n Apply a random horizontal or vertical flip to an image with a given probability.\n\n This class performs random image flipping and updates corresponding instance annotations such as\n bounding boxes and keypoints.\n\n Attributes:\n p (float): Probability of applying the flip. Must be between 0 and 1.\n direction (str): Direction of flip, either 'horizontal' or 'vertical'.\n flip_idx (array-like): Index mapping for flipping keypoints, if applicable.\n\n Methods:\n __call__: Apply the random flip transformation to an image and its annotations.\n\n Examples:\n >>> transform = RandomFlip(p=0.5, direction=\"horizontal\")\n >>> result = transform({\"img\": image, \"instances\": instances})\n >>> flipped_image = result[\"img\"]\n >>> flipped_instances = result[\"instances\"]\n \"\"\"\n\n def __init__(self, p: float = 0.5, direction: str = \"horizontal\", flip_idx: List[int] = None) -> None:\n \"\"\"\n Initialize the RandomFlip class with probability and direction.\n\n This class applies a random horizontal or vertical flip to an image with a given probability.\n It also updates any instances (bounding boxes, keypoints, etc.) accordingly.\n\n Args:\n p (float): The probability of applying the flip. Must be between 0 and 1.\n direction (str): The direction to apply the flip. 
Must be 'horizontal' or 'vertical'.\n flip_idx (List[int] | None): Index mapping for flipping keypoints, if any.\n\n Raises:\n AssertionError: If direction is not 'horizontal' or 'vertical', or if p is not between 0 and 1.\n\n Examples:\n >>> flip = RandomFlip(p=0.5, direction=\"horizontal\")\n >>> flip_with_idx = RandomFlip(p=0.7, direction=\"vertical\", flip_idx=[1, 0, 3, 2, 5, 4])\n \"\"\"\n assert direction in {\"horizontal\", \"vertical\"}, f\"Support direction `horizontal` or `vertical`, got {direction}\"\n assert 0 <= p <= 1.0, f\"The probability should be in range [0, 1], but got {p}.\"\n\n self.p = p\n self.direction = direction\n self.flip_idx = flip_idx\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.\n\n This method randomly flips the input image either horizontally or vertically based on the initialized\n probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to\n match the flipped image.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing the following keys:\n 'img' (np.ndarray): The image to be flipped.\n 'instances' (ultralytics.utils.instance.Instances): An object containing bounding boxes and\n optionally keypoints.\n\n Returns:\n (Dict[str, Any]): The same dictionary with the flipped image and updated instances:\n 'img' (np.ndarray): The flipped image.\n 'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image.\n\n Examples:\n >>> labels = {\"img\": np.random.rand(640, 640, 3), \"instances\": Instances(...)}\n >>> random_flip = RandomFlip(p=0.5, direction=\"horizontal\")\n >>> flipped_labels = random_flip(labels)\n \"\"\"\n img = labels[\"img\"]\n instances = labels.pop(\"instances\")\n instances.convert_bbox(format=\"xywh\")\n h, w = img.shape[:2]\n h = 1 if instances.normalized else h\n w = 1 if instances.normalized else w\n\n # WARNING: two separate if and calls to random.random() intentional for reproducibility with older versions\n if self.direction == \"vertical\" and random.random() < self.p:\n img = np.flipud(img)\n instances.flipud(h)\n if self.flip_idx is not None and instances.keypoints is not None:\n instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])\n if self.direction == \"horizontal\" and random.random() < self.p:\n img = np.fliplr(img)\n instances.fliplr(w)\n if self.flip_idx is not None and instances.keypoints is not None:\n instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])\n labels[\"img\"] = np.ascontiguousarray(img)\n labels[\"instances\"] = instances\n return labels", "chunk_type": "class", "name": "RandomFlip", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 1496, "end_line": 1588, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Apply a random horizontal or vertical flip to an image with a given probability.\n\nThis class performs random image flipping and updates corresponding instance annotations such as\nbounding boxes and keypoints.\n\nAttributes:\n p (float): Probability of applying the flip. 
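To illustrate why flip_idx is needed (a sketch with a hypothetical 5-keypoint layout rather than the usual COCO-17 one): mirroring the image swaps left/right semantics, so the keypoint columns are re-ordered after the coordinate flip.

>>> import numpy as np
>>> kpts = np.arange(5 * 3, dtype=np.float32).reshape(1, 5, 3)   # 1 instance, 5 keypoints of (x, y, visibility)
>>> flip_idx = [0, 2, 1, 4, 3]                 # keypoints 1<->2 and 3<->4 are left/right pairs
>>> flipped = np.ascontiguousarray(kpts[:, flip_idx, :])
>>> bool((flipped[0, 1] == kpts[0, 2]).all())  # slot 1 now holds what used to be keypoint 2
True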
Must be between 0 and 1.\n direction (str): Direction of flip, either 'horizontal' or 'vertical'.\n flip_idx (array-like): Index mapping for flipping keypoints, if applicable.\n\nMethods:\n __call__: Apply the random flip transformation to an image and its annotations.\n\nExamples:\n >>> transform = RandomFlip(p=0.5, direction=\"horizontal\")\n >>> result = transform({\"img\": image, \"instances\": instances})\n >>> flipped_image = result[\"img\"]\n >>> flipped_instances = result[\"instances\"]", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_RandomFlip_0fc993b3" }, { "content": "class LetterBox:\n \"\"\"\n Resize image and padding for detection, instance segmentation, pose.\n\n This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates\n corresponding labels and bounding boxes.\n\n Attributes:\n new_shape (tuple): Target shape (height, width) for resizing.\n auto (bool): Whether to use minimum rectangle.\n scale_fill (bool): Whether to stretch the image to new_shape.\n scaleup (bool): Whether to allow scaling up. If False, only scale down.\n stride (int): Stride for rounding padding.\n center (bool): Whether to center the image or align to top-left.\n\n Methods:\n __call__: Resize and pad image, update labels and bounding boxes.\n\n Examples:\n >>> transform = LetterBox(new_shape=(640, 640))\n >>> result = transform(labels)\n >>> resized_img = result[\"img\"]\n >>> updated_instances = result[\"instances\"]\n \"\"\"\n\n def __init__(\n self,\n new_shape: Tuple[int, int] = (640, 640),\n auto: bool = False,\n scale_fill: bool = False,\n scaleup: bool = True,\n center: bool = True,\n stride: int = 32,\n ):\n \"\"\"\n Initialize LetterBox object for resizing and padding images.\n\n This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation\n tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing.\n\n Args:\n new_shape (Tuple[int, int]): Target size (height, width) for the resized image.\n auto (bool): If True, use minimum rectangle to resize. If False, use new_shape directly.\n scale_fill (bool): If True, stretch the image to new_shape without padding.\n scaleup (bool): If True, allow scaling up. If False, only scale down.\n center (bool): If True, center the placed image. 
If False, place image in top-left corner.\n stride (int): Stride of the model (e.g., 32 for YOLOv5).\n\n Attributes:\n new_shape (Tuple[int, int]): Target size for the resized image.\n auto (bool): Flag for using minimum rectangle resizing.\n scale_fill (bool): Flag for stretching image without padding.\n scaleup (bool): Flag for allowing upscaling.\n stride (int): Stride value for ensuring image size is divisible by stride.\n\n Examples:\n >>> letterbox = LetterBox(new_shape=(640, 640), auto=False, scale_fill=False, scaleup=True, stride=32)\n >>> resized_img = letterbox(original_img)\n \"\"\"\n self.new_shape = new_shape\n self.auto = auto\n self.scale_fill = scale_fill\n self.scaleup = scaleup\n self.stride = stride\n self.center = center # Put the image in the middle or top-left\n\n def __call__(self, labels: Dict[str, Any] = None, image: np.ndarray = None) -> Union[Dict[str, Any], np.ndarray]:\n \"\"\"\n Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.\n\n This method applies letterboxing to the input image, which involves resizing the image while maintaining its\n aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly.\n\n Args:\n labels (Dict[str, Any] | None): A dictionary containing image data and associated labels, or empty dict if None.\n image (np.ndarray | None): The input image as a numpy array. If None, the image is taken from 'labels'.\n\n Returns:\n (Dict[str, Any] | nd.ndarray): If 'labels' is provided, returns an updated dictionary with the resized and padded image,\n updated labels, and additional metadata. If 'labels' is empty, returns the resized\n and padded image.\n\n Examples:\n >>> letterbox = LetterBox(new_shape=(640, 640))\n >>> result = letterbox(labels={\"img\": np.zeros((480, 640, 3)), \"instances\": Instances(...)})\n >>> resized_img = result[\"img\"]\n >>> updated_instances = result[\"instances\"]\n \"\"\"\n if labels is None:\n labels = {}\n img = labels.get(\"img\") if image is None else image\n shape = img.shape[:2] # current shape [height, width]\n new_shape = labels.pop(\"rect_shape\", self.new_shape)\n if isinstance(new_shape, int):\n new_shape = (new_shape, new_shape)\n\n # Scale ratio (new / old)\n r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])\n if not self.scaleup: # only scale down, do not scale up (for better val mAP)\n r = min(r, 1.0)\n\n # Compute padding\n ratio = r, r # width, height ratios\n new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))\n dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding\n if self.auto: # minimum rectangle\n dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding\n elif self.scale_fill: # stretch\n dw, dh = 0.0, 0.0\n new_unpad = (new_shape[1], new_shape[0])\n ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios\n\n if self.center:\n dw /= 2 # divide padding into 2 sides\n dh /= 2\n\n if shape[::-1] != new_unpad: # resize\n img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)\n if img.ndim == 2:\n img = img[..., None]\n\n top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))\n left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))\n h, w, c = img.shape\n if c == 3:\n img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))\n else: # multispectral\n pad_img = np.full((h + top + bottom, w + left + right, c), fill_value=114, 
dtype=img.dtype)\n pad_img[top : top + h, left : left + w] = img\n img = pad_img\n\n if labels.get(\"ratio_pad\"):\n labels[\"ratio_pad\"] = (labels[\"ratio_pad\"], (left, top)) # for evaluation\n\n if len(labels):\n labels = self._update_labels(labels, ratio, left, top)\n labels[\"img\"] = img\n labels[\"resized_shape\"] = new_shape\n return labels\n else:\n return img\n\n @staticmethod\n def _update_labels(labels: Dict[str, Any], ratio: Tuple[float, float], padw: float, padh: float) -> Dict[str, Any]:\n \"\"\"\n Update labels after applying letterboxing to an image.\n\n This method modifies the bounding box coordinates of instances in the labels\n to account for resizing and padding applied during letterboxing.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image labels and instances.\n ratio (Tuple[float, float]): Scaling ratios (width, height) applied to the image.\n padw (float): Padding width added to the image.\n padh (float): Padding height added to the image.\n\n Returns:\n (Dict[str, Any]): Updated labels dictionary with modified instance coordinates.\n\n Examples:\n >>> letterbox = LetterBox(new_shape=(640, 640))\n >>> labels = {\"instances\": Instances(...)}\n >>> ratio = (0.5, 0.5)\n >>> padw, padh = 10, 20\n >>> updated_labels = letterbox._update_labels(labels, ratio, padw, padh)\n \"\"\"\n labels[\"instances\"].convert_bbox(format=\"xyxy\")\n labels[\"instances\"].denormalize(*labels[\"img\"].shape[:2][::-1])\n labels[\"instances\"].scale(*ratio)\n labels[\"instances\"].add_padding(padw, padh)\n return labels", "chunk_type": "class", "name": "LetterBox", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 1591, "end_line": 1761, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Resize image and padding for detection, instance segmentation, pose.\n\nThis class resizes and pads images to a specified shape while preserving aspect ratio. It also updates\ncorresponding labels and bounding boxes.\n\nAttributes:\n new_shape (tuple): Target shape (height, width) for resizing.\n auto (bool): Whether to use minimum rectangle.\n scale_fill (bool): Whether to stretch the image to new_shape.\n scaleup (bool): Whether to allow scaling up. 
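A quick way to see LetterBox's image-only path: when called without a labels dict it scales the image to fit, pads the remainder with 114-gray, and returns the array directly. The sketch below is illustrative (not from the source); the input shape is arbitrary.

# Hedged usage sketch for LetterBox (illustrative, not from the source).
import numpy as np
from ultralytics.data.augment import LetterBox

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # H=480, W=640

letterbox = LetterBox(new_shape=(640, 640), auto=False, scaleup=True, stride=32)
padded = letterbox(image=img)  # no labels dict -> the padded array is returned directly

print(padded.shape)  # (640, 640, 3): content kept at scale 1.0, 80 px of gray padding top and bottom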
If False, only scale down.\n stride (int): Stride for rounding padding.\n center (bool): Whether to center the image or align to top-left.\n\nMethods:\n __call__: Resize and pad image, update labels and bounding boxes.\n\nExamples:\n >>> transform = LetterBox(new_shape=(640, 640))\n >>> result = transform(labels)\n >>> resized_img = result[\"img\"]\n >>> updated_instances = result[\"instances\"]", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_LetterBox_821d9e2a" }, { "content": "class CopyPaste(BaseMixTransform):\n \"\"\"\n CopyPaste class for applying Copy-Paste augmentation to image datasets.\n\n This class implements the Copy-Paste augmentation technique as described in the paper \"Simple Copy-Paste is a Strong\n Data Augmentation Method for Instance Segmentation\" (https://arxiv.org/abs/2012.07177). It combines objects from\n different images to create new training samples.\n\n Attributes:\n dataset (Any): The dataset to which Copy-Paste augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before Copy-Paste.\n p (float): Probability of applying Copy-Paste augmentation.\n\n Methods:\n _mix_transform: Apply Copy-Paste augmentation to the input labels.\n __call__: Apply the Copy-Paste transformation to images and annotations.\n\n Examples:\n >>> from ultralytics.data.augment import CopyPaste\n >>> dataset = YourDataset(...) 
# Your image dataset\n >>> copypaste = CopyPaste(dataset, p=0.5)\n >>> augmented_labels = copypaste(original_labels)\n \"\"\"\n\n def __init__(self, dataset=None, pre_transform=None, p: float = 0.5, mode: str = \"flip\") -> None:\n \"\"\"Initialize CopyPaste object with dataset, pre_transform, and probability of applying MixUp.\"\"\"\n super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)\n assert mode in {\"flip\", \"mixup\"}, f\"Expected `mode` to be `flip` or `mixup`, but got {mode}.\"\n self.mode = mode\n\n def _mix_transform(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Apply Copy-Paste augmentation to combine objects from another image into the current image.\"\"\"\n labels2 = labels[\"mix_labels\"][0]\n return self._transform(labels, labels2)\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Apply Copy-Paste augmentation to an image and its labels.\"\"\"\n if len(labels[\"instances\"].segments) == 0 or self.p == 0:\n return labels\n if self.mode == \"flip\":\n return self._transform(labels)\n\n # Get index of one or three other images\n indexes = self.get_indexes()\n if isinstance(indexes, int):\n indexes = [indexes]\n\n # Get images information will be used for Mosaic or MixUp\n mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]\n\n if self.pre_transform is not None:\n for i, data in enumerate(mix_labels):\n mix_labels[i] = self.pre_transform(data)\n labels[\"mix_labels\"] = mix_labels\n\n # Update cls and texts\n labels = self._update_label_text(labels)\n # Mosaic or MixUp\n labels = self._mix_transform(labels)\n labels.pop(\"mix_labels\", None)\n return labels\n\n def _transform(self, labels1: Dict[str, Any], labels2: Dict[str, Any] = {}) -> Dict[str, Any]:\n \"\"\"Apply Copy-Paste augmentation to combine objects from another image into the current image.\"\"\"\n im = labels1[\"img\"]\n if \"mosaic_border\" not in labels1:\n im = im.copy() # avoid modifying original non-mosaic image\n cls = labels1[\"cls\"]\n h, w = im.shape[:2]\n instances = labels1.pop(\"instances\")\n instances.convert_bbox(format=\"xyxy\")\n instances.denormalize(w, h)\n\n im_new = np.zeros(im.shape, np.uint8)\n instances2 = labels2.pop(\"instances\", None)\n if instances2 is None:\n instances2 = deepcopy(instances)\n instances2.fliplr(w)\n ioa = bbox_ioa(instances2.bboxes, instances.bboxes) # intersection over area, (N, M)\n indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )\n n = len(indexes)\n sorted_idx = np.argsort(ioa.max(1)[indexes])\n indexes = indexes[sorted_idx]\n for j in indexes[: round(self.p * n)]:\n cls = np.concatenate((cls, labels2.get(\"cls\", cls)[[j]]), axis=0)\n instances = Instances.concatenate((instances, instances2[[j]]), axis=0)\n cv2.drawContours(im_new, instances2.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)\n\n result = labels2.get(\"img\", cv2.flip(im, 1)) # augment segments\n if result.ndim == 2: # cv2.flip would eliminate the last dimension for grayscale images\n result = result[..., None]\n i = im_new.astype(bool)\n im[i] = result[i]\n\n labels1[\"img\"] = im\n labels1[\"cls\"] = cls\n labels1[\"instances\"] = instances\n return labels1", "chunk_type": "class", "name": "CopyPaste", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 1764, "end_line": 1861, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "CopyPaste class for applying Copy-Paste augmentation to image datasets.\n\nThis class implements the Copy-Paste augmentation technique as described in the 
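Because the default "flip" mode pastes mirrored copies of the image's own segments, CopyPaste can be tried on a single labelled sample without a dataset. The sketch below is illustrative only: coordinates are made up, and it assumes the Instances constructor accepts pixel-space xyxy boxes plus polygon segments as in current ultralytics.

# Hedged usage sketch for CopyPaste in "flip" mode (illustrative, not from the source).
import numpy as np
from ultralytics.data.augment import CopyPaste
from ultralytics.utils.instance import Instances

h, w = 480, 640
img = np.full((h, w, 3), 114, dtype=np.uint8)

# One object on the left half: a pixel-space xyxy box and a matching triangular polygon segment
bboxes = np.array([[50, 100, 150, 200]], dtype=np.float32)
segments = [np.array([[50, 100], [150, 100], [100, 200]], dtype=np.float32)]
labels = {
    "img": img,
    "cls": np.array([[0]]),
    "instances": Instances(bboxes=bboxes, segments=segments, bbox_format="xyxy", normalized=False),
}

copypaste = CopyPaste(p=1.0, mode="flip")  # flip mode needs no dataset: it pastes mirrored copies
out = copypaste(labels)
print(out["cls"].shape, out["instances"].bboxes.shape)  # one pasted instance added when IoA < 0.30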
paper \"Simple Copy-Paste is a Strong\nData Augmentation Method for Instance Segmentation\" (https://arxiv.org/abs/2012.07177). It combines objects from\ndifferent images to create new training samples.\n\nAttributes:\n dataset (Any): The dataset to which Copy-Paste augmentation will be applied.\n pre_transform (Callable | None): Optional transform to apply before Copy-Paste.\n p (float): Probability of applying Copy-Paste augmentation.\n\nMethods:\n _mix_transform: Apply Copy-Paste augmentation to the input labels.\n __call__: Apply the Copy-Paste transformation to images and annotations.\n\nExamples:\n >>> from ultralytics.data.augment import CopyPaste\n >>> dataset = YourDataset(...) # Your image dataset\n >>> copypaste = CopyPaste(dataset, p=0.5)\n >>> augmented_labels = copypaste(original_labels)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations", "BaseMixTransform" ], "chunk_id": "class_CopyPaste_444abb06" }, { "content": "class Albumentations:\n \"\"\"\n Albumentations transformations for image augmentation.\n\n This class applies various image transformations using the Albumentations library. It includes operations such as\n Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes\n in brightness and contrast, RandomGamma, and image quality reduction through compression.\n\n Attributes:\n p (float): Probability of applying the transformations.\n transform (albumentations.Compose): Composed Albumentations transforms.\n contains_spatial (bool): Indicates if the transforms include spatial operations.\n\n Methods:\n __call__: Apply the Albumentations transformations to the input labels.\n\n Examples:\n >>> transform = Albumentations(p=0.5)\n >>> augmented_labels = transform(labels)\n\n Notes:\n - The Albumentations package must be installed to use this class.\n - If the package is not installed or an error occurs during initialization, the transform will be set to None.\n - Spatial transforms are handled differently and require special processing for bounding boxes.\n \"\"\"\n\n def __init__(self, p: float = 1.0) -> None:\n \"\"\"\n Initialize the Albumentations transform object for YOLO bbox formatted parameters.\n\n This class applies various image augmentations using the Albumentations library, including Blur, Median Blur,\n conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and\n contrast, RandomGamma, and image quality reduction through compression.\n\n Args:\n p (float): Probability of applying the augmentations. 
Must be between 0 and 1.\n\n Attributes:\n p (float): Probability of applying the augmentations.\n transform (albumentations.Compose): Composed Albumentations transforms.\n contains_spatial (bool): Indicates if the transforms include spatial transformations.\n\n Raises:\n ImportError: If the Albumentations package is not installed.\n Exception: For any other errors during initialization.\n\n Examples:\n >>> transform = Albumentations(p=0.5)\n >>> augmented = transform(image=image, bboxes=bboxes, class_labels=classes)\n >>> augmented_image = augmented[\"image\"]\n >>> augmented_bboxes = augmented[\"bboxes\"]\n\n Notes:\n - Requires Albumentations version 1.0.3 or higher.\n - Spatial transforms are handled differently to ensure bbox compatibility.\n - Some transforms are applied with very low probability (0.01) by default.\n \"\"\"\n self.p = p\n self.transform = None\n prefix = colorstr(\"albumentations: \")\n\n try:\n import os\n\n os.environ[\"NO_ALBUMENTATIONS_UPDATE\"] = \"1\" # suppress Albumentations upgrade message\n import albumentations as A\n\n check_version(A.__version__, \"1.0.3\", hard=True) # version requirement\n\n # List of possible spatial transforms\n spatial_transforms = {\n \"Affine\",\n \"BBoxSafeRandomCrop\",\n \"CenterCrop\",\n \"CoarseDropout\",\n \"Crop\",\n \"CropAndPad\",\n \"CropNonEmptyMaskIfExists\",\n \"D4\",\n \"ElasticTransform\",\n \"Flip\",\n \"GridDistortion\",\n \"GridDropout\",\n \"HorizontalFlip\",\n \"Lambda\",\n \"LongestMaxSize\",\n \"MaskDropout\",\n \"MixUp\",\n \"Morphological\",\n \"NoOp\",\n \"OpticalDistortion\",\n \"PadIfNeeded\",\n \"Perspective\",\n \"PiecewiseAffine\",\n \"PixelDropout\",\n \"RandomCrop\",\n \"RandomCropFromBorders\",\n \"RandomGridShuffle\",\n \"RandomResizedCrop\",\n \"RandomRotate90\",\n \"RandomScale\",\n \"RandomSizedBBoxSafeCrop\",\n \"RandomSizedCrop\",\n \"Resize\",\n \"Rotate\",\n \"SafeRotate\",\n \"ShiftScaleRotate\",\n \"SmallestMaxSize\",\n \"Transpose\",\n \"VerticalFlip\",\n \"XYMasking\",\n } # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms\n\n # Transforms\n T = [\n A.Blur(p=0.01),\n A.MedianBlur(p=0.01),\n A.ToGray(p=0.01),\n A.CLAHE(p=0.01),\n A.RandomBrightnessContrast(p=0.0),\n A.RandomGamma(p=0.0),\n A.ImageCompression(quality_range=(75, 100), p=0.0),\n ]\n\n # Compose transforms\n self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T)\n self.transform = (\n A.Compose(T, bbox_params=A.BboxParams(format=\"yolo\", label_fields=[\"class_labels\"]))\n if self.contains_spatial\n else A.Compose(T)\n )\n if hasattr(self.transform, \"set_random_seed\"):\n # Required for deterministic transforms in albumentations>=1.4.21\n self.transform.set_random_seed(torch.initial_seed())\n LOGGER.info(prefix + \", \".join(f\"{x}\".replace(\"always_apply=False, \", \"\") for x in T if x.p))\n except ImportError: # package not installed, skip\n pass\n except Exception as e:\n LOGGER.info(f\"{prefix}{e}\")\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Apply Albumentations transformations to input labels.\n\n This method applies a series of image augmentations using the Albumentations library. It can perform both\n spatial and non-spatial transformations on the input image and its corresponding labels.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image data and annotations. 
Expected keys are:\n - 'img': np.ndarray representing the image\n - 'cls': np.ndarray of class labels\n - 'instances': object containing bounding boxes and other instance information\n\n Returns:\n (Dict[str, Any]): The input dictionary with augmented image and updated annotations.\n\n Examples:\n >>> transform = Albumentations(p=0.5)\n >>> labels = {\n ... \"img\": np.random.rand(640, 640, 3),\n ... \"cls\": np.array([0, 1]),\n ... \"instances\": Instances(bboxes=np.array([[0, 0, 1, 1], [0.5, 0.5, 0.8, 0.8]])),\n ... }\n >>> augmented = transform(labels)\n >>> assert augmented[\"img\"].shape == (640, 640, 3)\n\n Notes:\n - The method applies transformations with probability self.p.\n - Spatial transforms update bounding boxes, while non-spatial transforms only modify the image.\n - Requires the Albumentations library to be installed.\n \"\"\"\n if self.transform is None or random.random() > self.p:\n return labels\n\n im = labels[\"img\"]\n if im.shape[2] != 3: # Only apply Albumentation on 3-channel images\n return labels\n\n if self.contains_spatial:\n cls = labels[\"cls\"]\n if len(cls):\n labels[\"instances\"].convert_bbox(\"xywh\")\n labels[\"instances\"].normalize(*im.shape[:2][::-1])\n bboxes = labels[\"instances\"].bboxes\n # TODO: add supports of segments and keypoints\n new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed\n if len(new[\"class_labels\"]) > 0: # skip update if no bbox in new im\n labels[\"img\"] = new[\"image\"]\n labels[\"cls\"] = np.array(new[\"class_labels\"])\n bboxes = np.array(new[\"bboxes\"], dtype=np.float32)\n labels[\"instances\"].update(bboxes=bboxes)\n else:\n labels[\"img\"] = self.transform(image=labels[\"img\"])[\"image\"] # transformed\n\n return labels", "chunk_type": "class", "name": "Albumentations", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 1864, "end_line": 2058, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Albumentations transformations for image augmentation.\n\nThis class applies various image transformations using the Albumentations library. 
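With the default transform list (Blur, MedianBlur, ToGray, CLAHE, brightness/contrast, gamma, compression) nothing spatial is present, so the wrapper only touches the image and passes the instances through; it also degrades to a no-op if the albumentations package is missing. The sketch below is illustrative, with made-up values.

# Hedged usage sketch for the Albumentations wrapper (illustrative, not from the source).
import numpy as np
from ultralytics.data.augment import Albumentations
from ultralytics.utils.instance import Instances

labels = {
    "img": np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8),  # must be 3-channel
    "cls": np.array([[0]]),
    "instances": Instances(bboxes=np.array([[0.5, 0.5, 0.2, 0.2]], dtype=np.float32), normalized=True),
}

transform = Albumentations(p=1.0)  # becomes a pass-through if albumentations is not installed
out = transform(labels)
print(out["img"].shape, out["instances"].bboxes)  # image possibly altered, boxes unchanged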
It includes operations such as\nBlur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes\nin brightness and contrast, RandomGamma, and image quality reduction through compression.\n\nAttributes:\n p (float): Probability of applying the transformations.\n transform (albumentations.Compose): Composed Albumentations transforms.\n contains_spatial (bool): Indicates if the transforms include spatial operations.\n\nMethods:\n __call__: Apply the Albumentations transformations to the input labels.\n\nExamples:\n >>> transform = Albumentations(p=0.5)\n >>> augmented_labels = transform(labels)\n\nNotes:\n - The Albumentations package must be installed to use this class.\n - If the package is not installed or an error occurs during initialization, the transform will be set to None.\n - Spatial transforms are handled differently and require special processing for bounding boxes.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_Albumentations_168d40dd" }, { "content": "class Format:\n \"\"\"\n A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.\n\n This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.\n\n Attributes:\n bbox_format (str): Format for bounding boxes. 
Options are 'xywh' or 'xyxy'.\n normalize (bool): Whether to normalize bounding boxes.\n return_mask (bool): Whether to return instance masks for segmentation.\n return_keypoint (bool): Whether to return keypoints for pose estimation.\n return_obb (bool): Whether to return oriented bounding boxes.\n mask_ratio (int): Downsample ratio for masks.\n mask_overlap (bool): Whether to overlap masks.\n batch_idx (bool): Whether to keep batch indexes.\n bgr (float): The probability to return BGR images.\n\n Methods:\n __call__: Format labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints.\n _format_img: Convert image from Numpy array to PyTorch tensor.\n _format_segments: Convert polygon points to bitmap masks.\n\n Examples:\n >>> formatter = Format(bbox_format=\"xywh\", normalize=True, return_mask=True)\n >>> formatted_labels = formatter(labels)\n >>> img = formatted_labels[\"img\"]\n >>> bboxes = formatted_labels[\"bboxes\"]\n >>> masks = formatted_labels[\"masks\"]\n \"\"\"\n\n def __init__(\n self,\n bbox_format: str = \"xywh\",\n normalize: bool = True,\n return_mask: bool = False,\n return_keypoint: bool = False,\n return_obb: bool = False,\n mask_ratio: int = 4,\n mask_overlap: bool = True,\n batch_idx: bool = True,\n bgr: float = 0.0,\n ):\n \"\"\"\n Initialize the Format class with given parameters for image and instance annotation formatting.\n\n This class standardizes image and instance annotations for object detection, instance segmentation, and pose\n estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`.\n\n Args:\n bbox_format (str): Format for bounding boxes. Options are 'xywh', 'xyxy', etc.\n normalize (bool): Whether to normalize bounding boxes to [0,1].\n return_mask (bool): If True, returns instance masks for segmentation tasks.\n return_keypoint (bool): If True, returns keypoints for pose estimation tasks.\n return_obb (bool): If True, returns oriented bounding boxes.\n mask_ratio (int): Downsample ratio for masks.\n mask_overlap (bool): If True, allows mask overlap.\n batch_idx (bool): If True, keeps batch indexes.\n bgr (float): Probability of returning BGR images instead of RGB.\n\n Attributes:\n bbox_format (str): Format for bounding boxes.\n normalize (bool): Whether bounding boxes are normalized.\n return_mask (bool): Whether to return instance masks.\n return_keypoint (bool): Whether to return keypoints.\n return_obb (bool): Whether to return oriented bounding boxes.\n mask_ratio (int): Downsample ratio for masks.\n mask_overlap (bool): Whether masks can overlap.\n batch_idx (bool): Whether to keep batch indexes.\n bgr (float): The probability to return BGR images.\n\n Examples:\n >>> format = Format(bbox_format=\"xyxy\", return_mask=True, return_keypoint=False)\n >>> print(format.bbox_format)\n xyxy\n \"\"\"\n self.bbox_format = bbox_format\n self.normalize = normalize\n self.return_mask = return_mask # set False when training detection only\n self.return_keypoint = return_keypoint\n self.return_obb = return_obb\n self.mask_ratio = mask_ratio\n self.mask_overlap = mask_overlap\n self.batch_idx = batch_idx # keep the batch indexes\n self.bgr = bgr\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Format image annotations for object detection, instance segmentation, and pose estimation tasks.\n\n This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch\n DataLoader. 
It processes the input labels dictionary, converting annotations to the specified format and\n applying normalization if required.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image and annotation data with the following keys:\n - 'img': The input image as a numpy array.\n - 'cls': Class labels for instances.\n - 'instances': An Instances object containing bounding boxes, segments, and keypoints.\n\n Returns:\n (Dict[str, Any]): A dictionary with formatted data, including:\n - 'img': Formatted image tensor.\n - 'cls': Class label's tensor.\n - 'bboxes': Bounding boxes tensor in the specified format.\n - 'masks': Instance masks tensor (if return_mask is True).\n - 'keypoints': Keypoints tensor (if return_keypoint is True).\n - 'batch_idx': Batch index tensor (if batch_idx is True).\n\n Examples:\n >>> formatter = Format(bbox_format=\"xywh\", normalize=True, return_mask=True)\n >>> labels = {\"img\": np.random.rand(640, 640, 3), \"cls\": np.array([0, 1]), \"instances\": Instances(...)}\n >>> formatted_labels = formatter(labels)\n >>> print(formatted_labels.keys())\n \"\"\"\n img = labels.pop(\"img\")\n h, w = img.shape[:2]\n cls = labels.pop(\"cls\")\n instances = labels.pop(\"instances\")\n instances.convert_bbox(format=self.bbox_format)\n instances.denormalize(w, h)\n nl = len(instances)\n\n if self.return_mask:\n if nl:\n masks, instances, cls = self._format_segments(instances, cls, w, h)\n masks = torch.from_numpy(masks)\n else:\n masks = torch.zeros(\n 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio\n )\n labels[\"masks\"] = masks\n labels[\"img\"] = self._format_img(img)\n labels[\"cls\"] = torch.from_numpy(cls) if nl else torch.zeros(nl)\n labels[\"bboxes\"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))\n if self.return_keypoint:\n labels[\"keypoints\"] = (\n torch.empty(0, 3) if instances.keypoints is None else torch.from_numpy(instances.keypoints)\n )\n if self.normalize:\n labels[\"keypoints\"][..., 0] /= w\n labels[\"keypoints\"][..., 1] /= h\n if self.return_obb:\n labels[\"bboxes\"] = (\n xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))\n )\n # NOTE: need to normalize obb in xywhr format for width-height consistency\n if self.normalize:\n labels[\"bboxes\"][:, [0, 2]] /= w\n labels[\"bboxes\"][:, [1, 3]] /= h\n # Then we can use collate_fn\n if self.batch_idx:\n labels[\"batch_idx\"] = torch.zeros(nl)\n return labels\n\n def _format_img(self, img: np.ndarray) -> torch.Tensor:\n \"\"\"\n Format an image for YOLO from a Numpy array to a PyTorch tensor.\n\n This function performs the following operations:\n 1. Ensures the image has 3 dimensions (adds a channel dimension if needed).\n 2. Transposes the image from HWC to CHW format.\n 3. Optionally flips the color channels from RGB to BGR.\n 4. Converts the image to a contiguous array.\n 5. 
Converts the Numpy array to a PyTorch tensor.\n\n Args:\n img (np.ndarray): Input image as a Numpy array with shape (H, W, C) or (H, W).\n\n Returns:\n (torch.Tensor): Formatted image as a PyTorch tensor with shape (C, H, W).\n\n Examples:\n >>> import numpy as np\n >>> img = np.random.rand(100, 100, 3)\n >>> formatted_img = self._format_img(img)\n >>> print(formatted_img.shape)\n torch.Size([3, 100, 100])\n \"\"\"\n if len(img.shape) < 3:\n img = np.expand_dims(img, -1)\n img = img.transpose(2, 0, 1)\n img = np.ascontiguousarray(img[::-1] if random.uniform(0, 1) > self.bgr and img.shape[0] == 3 else img)\n img = torch.from_numpy(img)\n return img\n\n def _format_segments(\n self, instances: Instances, cls: np.ndarray, w: int, h: int\n ) -> Tuple[np.ndarray, Instances, np.ndarray]:\n \"\"\"\n Convert polygon segments to bitmap masks.\n\n Args:\n instances (Instances): Object containing segment information.\n cls (np.ndarray): Class labels for each instance.\n w (int): Width of the image.\n h (int): Height of the image.\n\n Returns:\n masks (np.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.\n instances (Instances): Updated instances object with sorted segments if mask_overlap is True.\n cls (np.ndarray): Updated class labels, sorted if mask_overlap is True.\n\n Notes:\n - If self.mask_overlap is True, masks are overlapped and sorted by area.\n - If self.mask_overlap is False, each mask is represented separately.\n - Masks are downsampled according to self.mask_ratio.\n \"\"\"\n segments = instances.segments\n if self.mask_overlap:\n masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)\n masks = masks[None] # (640, 640) -> (1, 640, 640)\n instances = instances[sorted_idx]\n cls = cls[sorted_idx]\n else:\n masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)\n\n return masks, instances, cls", "chunk_type": "class", "name": "Format", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2061, "end_line": 2277, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.\n\nThis class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.\n\nAttributes:\n bbox_format (str): Format for bounding boxes. 
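The Format chunk is the last step before collation, so a small end-to-end check is useful. The sketch below is illustrative (not from the source): it assumes the Instances constructor signature used elsewhere in these examples, and the box and class values are made up.

# Hedged usage sketch for Format (illustrative, not from the source).
import numpy as np
from ultralytics.data.augment import Format
from ultralytics.utils.instance import Instances

labels = {
    "img": np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8),
    "cls": np.array([[0], [3]]),
    "instances": Instances(
        bboxes=np.array([[0.5, 0.5, 0.2, 0.2], [0.25, 0.25, 0.1, 0.1]], dtype=np.float32),
        bbox_format="xywh",
        normalized=True,
    ),
}

formatter = Format(bbox_format="xywh", normalize=True, batch_idx=True)
out = formatter(labels)
print(out["img"].shape)                       # torch.Size([3, 640, 640]) CHW tensor
print(out["bboxes"].shape, out["cls"].shape)  # torch.Size([2, 4]) and torch.Size([2, 1])
print(out["batch_idx"])                       # tensor([0., 0.]) placeholder filled in by collate_fn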
Options are 'xywh' or 'xyxy'.\n normalize (bool): Whether to normalize bounding boxes.\n return_mask (bool): Whether to return instance masks for segmentation.\n return_keypoint (bool): Whether to return keypoints for pose estimation.\n return_obb (bool): Whether to return oriented bounding boxes.\n mask_ratio (int): Downsample ratio for masks.\n mask_overlap (bool): Whether to overlap masks.\n batch_idx (bool): Whether to keep batch indexes.\n bgr (float): The probability to return BGR images.\n\nMethods:\n __call__: Format labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints.\n _format_img: Convert image from Numpy array to PyTorch tensor.\n _format_segments: Convert polygon points to bitmap masks.\n\nExamples:\n >>> formatter = Format(bbox_format=\"xywh\", normalize=True, return_mask=True)\n >>> formatted_labels = formatter(labels)\n >>> img = formatted_labels[\"img\"]\n >>> bboxes = formatted_labels[\"bboxes\"]\n >>> masks = formatted_labels[\"masks\"]", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_Format_d7345388" }, { "content": "class LoadVisualPrompt:\n \"\"\"Create visual prompts from bounding boxes or masks for model input.\"\"\"\n\n def __init__(self, scale_factor: float = 1 / 8) -> None:\n \"\"\"\n Initialize the LoadVisualPrompt with a scale factor.\n\n Args:\n scale_factor (float): Factor to scale the input image dimensions.\n \"\"\"\n self.scale_factor = scale_factor\n\n def make_mask(self, boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:\n \"\"\"\n Create binary masks from bounding boxes.\n\n Args:\n boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).\n h (int): Height of the mask.\n w (int): Width of the mask.\n\n Returns:\n (torch.Tensor): Binary masks with shape (N, h, w).\n \"\"\"\n x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)\n r = torch.arange(w)[None, None, :] # rows shape(1,1,w)\n c = torch.arange(h)[None, :, None] # cols shape(1,h,1)\n\n return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Process labels to create visual prompts.\n\n Args:\n labels (Dict[str, Any]): Dictionary containing image data and annotations.\n\n Returns:\n (Dict[str, Any]): Updated labels with visual prompts added.\n \"\"\"\n imgsz = labels[\"img\"].shape[1:]\n bboxes, masks = None, None\n if \"bboxes\" in labels:\n bboxes = labels[\"bboxes\"]\n bboxes = xywh2xyxy(bboxes) * torch.tensor(imgsz)[[1, 0, 1, 0]] # denormalize boxes\n\n cls = labels[\"cls\"].squeeze(-1).to(torch.int)\n visuals = self.get_visuals(cls, imgsz, bboxes=bboxes, masks=masks)\n 
labels[\"visuals\"] = visuals\n return labels\n\n def get_visuals(\n self,\n category: Union[int, np.ndarray, torch.Tensor],\n shape: Tuple[int, int],\n bboxes: Union[np.ndarray, torch.Tensor] = None,\n masks: Union[np.ndarray, torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Generate visual masks based on bounding boxes or masks.\n\n Args:\n category (int | np.ndarray | torch.Tensor): The category labels for the objects.\n shape (Tuple[int, int]): The shape of the image (height, width).\n bboxes (np.ndarray | torch.Tensor, optional): Bounding boxes for the objects, xyxy format.\n masks (np.ndarray | torch.Tensor, optional): Masks for the objects.\n\n Returns:\n (torch.Tensor): A tensor containing the visual masks for each category.\n\n Raises:\n ValueError: If neither bboxes nor masks are provided.\n \"\"\"\n masksz = (int(shape[0] * self.scale_factor), int(shape[1] * self.scale_factor))\n if bboxes is not None:\n if isinstance(bboxes, np.ndarray):\n bboxes = torch.from_numpy(bboxes)\n bboxes *= self.scale_factor\n masks = self.make_mask(bboxes, *masksz).float()\n elif masks is not None:\n if isinstance(masks, np.ndarray):\n masks = torch.from_numpy(masks) # (N, H, W)\n masks = F.interpolate(masks.unsqueeze(1), masksz, mode=\"nearest\").squeeze(1).float()\n else:\n raise ValueError(\"LoadVisualPrompt must have bboxes or masks in the label\")\n if not isinstance(category, torch.Tensor):\n category = torch.tensor(category, dtype=torch.int)\n cls_unique, inverse_indices = torch.unique(category, sorted=True, return_inverse=True)\n # NOTE: `cls` indices from RandomLoadText should be continuous.\n # if len(cls_unique):\n # assert len(cls_unique) == cls_unique[-1] + 1, (\n # f\"Expected a continuous range of class indices, but got {cls_unique}\"\n # )\n visuals = torch.zeros(len(cls_unique), *masksz)\n for idx, mask in zip(inverse_indices, masks):\n visuals[idx] = torch.logical_or(visuals[idx], mask)\n return visuals", "chunk_type": "class", "name": "LoadVisualPrompt", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2280, "end_line": 2376, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Create visual prompts from bounding boxes or masks for model input.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_LoadVisualPrompt_0bbfddc1" }, { "content": "class RandomLoadText:\n \"\"\"\n Randomly sample positive and negative texts and update class indices accordingly.\n\n This class is responsible for sampling texts from a given set of class texts, including both positive\n (present in the image) and negative (not present in the image) samples. 
It updates the class indices\n to reflect the sampled texts and can optionally pad the text list to a fixed length.\n\n Attributes:\n prompt_format (str): Format string for text prompts.\n neg_samples (Tuple[int, int]): Range for randomly sampling negative texts.\n max_samples (int): Maximum number of different text samples in one image.\n padding (bool): Whether to pad texts to max_samples.\n padding_value (str): The text used for padding when padding is True.\n\n Methods:\n __call__: Process the input labels and return updated classes and texts.\n\n Examples:\n >>> loader = RandomLoadText(prompt_format=\"Object: {}\", neg_samples=(5, 10), max_samples=20)\n >>> labels = {\"cls\": [0, 1, 2], \"texts\": [[\"cat\"], [\"dog\"], [\"bird\"]], \"instances\": [...]}\n >>> updated_labels = loader(labels)\n >>> print(updated_labels[\"texts\"])\n ['Object: cat', 'Object: dog', 'Object: bird', 'Object: elephant', 'Object: car']\n \"\"\"\n\n def __init__(\n self,\n prompt_format: str = \"{}\",\n neg_samples: Tuple[int, int] = (80, 80),\n max_samples: int = 80,\n padding: bool = False,\n padding_value: List[str] = [\"\"],\n ) -> None:\n \"\"\"\n Initialize the RandomLoadText class for randomly sampling positive and negative texts.\n\n This class is designed to randomly sample positive texts and negative texts, and update the class\n indices accordingly to the number of samples. It can be used for text-based object detection tasks.\n\n Args:\n prompt_format (str): Format string for the prompt. The format string should\n contain a single pair of curly braces {} where the text will be inserted.\n neg_samples (Tuple[int, int]): A range to randomly sample negative texts. The first integer\n specifies the minimum number of negative samples, and the second integer specifies the\n maximum.\n max_samples (int): The maximum number of different text samples in one image.\n padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always\n be equal to max_samples.\n padding_value (str): The padding text to use when padding is True.\n\n Attributes:\n prompt_format (str): The format string for the prompt.\n neg_samples (Tuple[int, int]): The range for sampling negative texts.\n max_samples (int): The maximum number of text samples.\n padding (bool): Whether padding is enabled.\n padding_value (str): The value used for padding.\n\n Examples:\n >>> random_load_text = RandomLoadText(prompt_format=\"Object: {}\", neg_samples=(50, 100), max_samples=120)\n >>> random_load_text.prompt_format\n 'Object: {}'\n >>> random_load_text.neg_samples\n (50, 100)\n >>> random_load_text.max_samples\n 120\n \"\"\"\n self.prompt_format = prompt_format\n self.neg_samples = neg_samples\n self.max_samples = max_samples\n self.padding = padding\n self.padding_value = padding_value\n\n def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Randomly sample positive and negative texts and update class indices accordingly.\n\n This method samples positive texts based on the existing class labels in the image, and randomly\n selects negative texts from the remaining classes. It then updates the class indices to match the\n new sampled text order.\n\n Args:\n labels (Dict[str, Any]): A dictionary containing image labels and metadata. 
Must include 'texts' and 'cls' keys.\n\n Returns:\n (Dict[str, Any]): Updated labels dictionary with new 'cls' and 'texts' entries.\n\n Examples:\n >>> loader = RandomLoadText(prompt_format=\"A photo of {}\", neg_samples=(5, 10), max_samples=20)\n >>> labels = {\"cls\": np.array([[0], [1], [2]]), \"texts\": [[\"dog\"], [\"cat\"], [\"bird\"]]}\n >>> updated_labels = loader(labels)\n \"\"\"\n assert \"texts\" in labels, \"No texts found in labels.\"\n class_texts = labels[\"texts\"]\n num_classes = len(class_texts)\n cls = np.asarray(labels.pop(\"cls\"), dtype=int)\n pos_labels = np.unique(cls).tolist()\n\n if len(pos_labels) > self.max_samples:\n pos_labels = random.sample(pos_labels, k=self.max_samples)\n\n neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))\n neg_labels = [i for i in range(num_classes) if i not in pos_labels]\n neg_labels = random.sample(neg_labels, k=neg_samples)\n\n sampled_labels = pos_labels + neg_labels\n # Randomness\n # random.shuffle(sampled_labels)\n\n label2ids = {label: i for i, label in enumerate(sampled_labels)}\n valid_idx = np.zeros(len(labels[\"instances\"]), dtype=bool)\n new_cls = []\n for i, label in enumerate(cls.squeeze(-1).tolist()):\n if label not in label2ids:\n continue\n valid_idx[i] = True\n new_cls.append([label2ids[label]])\n labels[\"instances\"] = labels[\"instances\"][valid_idx]\n labels[\"cls\"] = np.array(new_cls)\n\n # Randomly select one prompt when there's more than one prompts\n texts = []\n for label in sampled_labels:\n prompts = class_texts[label]\n assert len(prompts) > 0\n prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])\n texts.append(prompt)\n\n if self.padding:\n valid_labels = len(pos_labels) + len(neg_labels)\n num_padding = self.max_samples - valid_labels\n if num_padding > 0:\n texts += random.choices(self.padding_value, k=num_padding)\n\n assert len(texts) == self.max_samples\n labels[\"texts\"] = texts\n return labels", "chunk_type": "class", "name": "RandomLoadText", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2379, "end_line": 2515, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Randomly sample positive and negative texts and update class indices accordingly.\n\nThis class is responsible for sampling texts from a given set of class texts, including both positive\n(present in the image) and negative (not present in the image) samples. 
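Text sampling is easiest to see with a tiny class list. The sketch below is illustrative (not from the source): the class texts and boxes are made up, padding is enabled so the output length is deterministic, and it assumes the Instances constructor signature used in the earlier examples.

# Hedged usage sketch for RandomLoadText (illustrative, not from the source).
import numpy as np
from ultralytics.data.augment import RandomLoadText
from ultralytics.utils.instance import Instances

class_texts = [["person"], ["bicycle"], ["car"], ["dog"], ["cat"]]  # 5 dataset classes (made up)
labels = {
    "texts": class_texts,
    "cls": np.array([[0], [2]]),  # this image contains classes 0 and 2
    "instances": Instances(
        bboxes=np.array([[0.5, 0.5, 0.2, 0.2], [0.25, 0.25, 0.1, 0.1]], dtype=np.float32),
        normalized=True,
    ),
}

loader = RandomLoadText(prompt_format="a photo of a {}", neg_samples=(1, 2), max_samples=4, padding=True)
out = loader(labels)
print(out["texts"])  # positives first ('person', 'car'), then sampled negatives, padded to 4 entries
print(out["cls"])    # class ids remapped to indices into out["texts"]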
It updates the class indices\nto reflect the sampled texts and can optionally pad the text list to a fixed length.\n\nAttributes:\n prompt_format (str): Format string for text prompts.\n neg_samples (Tuple[int, int]): Range for randomly sampling negative texts.\n max_samples (int): Maximum number of different text samples in one image.\n padding (bool): Whether to pad texts to max_samples.\n padding_value (str): The text used for padding when padding is True.\n\nMethods:\n __call__: Process the input labels and return updated classes and texts.\n\nExamples:\n >>> loader = RandomLoadText(prompt_format=\"Object: {}\", neg_samples=(5, 10), max_samples=20)\n >>> labels = {\"cls\": [0, 1, 2], \"texts\": [[\"cat\"], [\"dog\"], [\"bird\"]], \"instances\": [...]}\n >>> updated_labels = loader(labels)\n >>> print(updated_labels[\"texts\"])\n ['Object: cat', 'Object: dog', 'Object: bird', 'Object: elephant', 'Object: car']", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_RandomLoadText_28c8bd0a" }, { "content": "def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bool = False):\n \"\"\"\n Apply a series of image transformations for training.\n\n This function creates a composition of image augmentation techniques to prepare images for YOLO training.\n It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.\n\n Args:\n dataset (Dataset): The dataset object containing image data and annotations.\n imgsz (int): The target image size for resizing.\n hyp (IterableSimpleNamespace): A dictionary of hyperparameters controlling various aspects of the transformations.\n stretch (bool): If True, applies stretching to the image. 
If False, uses LetterBox resizing.\n\n Returns:\n (Compose): A composition of image transformations to be applied to the dataset.\n\n Examples:\n >>> from ultralytics.data.dataset import YOLODataset\n >>> from ultralytics.utils import IterableSimpleNamespace\n >>> dataset = YOLODataset(img_path=\"path/to/images\", imgsz=640)\n >>> hyp = IterableSimpleNamespace(mosaic=1.0, copy_paste=0.5, degrees=10.0, translate=0.2, scale=0.9)\n >>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)\n >>> augmented_data = transforms(dataset[0])\n \"\"\"\n mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic)\n affine = RandomPerspective(\n degrees=hyp.degrees,\n translate=hyp.translate,\n scale=hyp.scale,\n shear=hyp.shear,\n perspective=hyp.perspective,\n pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),\n )\n\n pre_transform = Compose([mosaic, affine])\n if hyp.copy_paste_mode == \"flip\":\n pre_transform.insert(1, CopyPaste(p=hyp.copy_paste, mode=hyp.copy_paste_mode))\n else:\n pre_transform.append(\n CopyPaste(\n dataset,\n pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), affine]),\n p=hyp.copy_paste,\n mode=hyp.copy_paste_mode,\n )\n )\n flip_idx = dataset.data.get(\"flip_idx\", []) # for keypoints augmentation\n if dataset.use_keypoints:\n kpt_shape = dataset.data.get(\"kpt_shape\", None)\n if len(flip_idx) == 0 and (hyp.fliplr > 0.0 or hyp.flipud > 0.0):\n hyp.fliplr = hyp.flipud = 0.0 # both fliplr and flipud require flip_idx\n LOGGER.warning(\"No 'flip_idx' array defined in data.yaml, disabling 'fliplr' and 'flipud' augmentations.\")\n elif flip_idx and (len(flip_idx) != kpt_shape[0]):\n raise ValueError(f\"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}\")\n\n return Compose(\n [\n pre_transform,\n MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),\n CutMix(dataset, pre_transform=pre_transform, p=hyp.cutmix),\n Albumentations(p=1.0),\n RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),\n RandomFlip(direction=\"vertical\", p=hyp.flipud, flip_idx=flip_idx),\n RandomFlip(direction=\"horizontal\", p=hyp.fliplr, flip_idx=flip_idx),\n ]\n ) # transforms", "chunk_type": "function", "name": "v8_transforms", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2518, "end_line": 2583, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Apply a series of image transformations for training.\n\nThis function creates a composition of image augmentation techniques to prepare images for YOLO training.\nIt includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.\n\nArgs:\n dataset (Dataset): The dataset object containing image data and annotations.\n imgsz (int): The target image size for resizing.\n hyp (IterableSimpleNamespace): A dictionary of hyperparameters controlling various aspects of the transformations.\n stretch (bool): If True, applies stretching to the image. 
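The sketch below shows which hyperparameter fields v8_transforms reads and how the Compose is assembled. The stub dataset is a hypothetical stand-in (not a real API): it carries just enough attributes to build the pipeline, but applying the result requires a real YOLODataset that can serve extra images for Mosaic/MixUp/CutMix.

# Hedged sketch of assembling the v8 training pipeline (illustrative, not from the source).
from types import SimpleNamespace

from ultralytics.data.augment import v8_transforms
from ultralytics.utils import IterableSimpleNamespace

hyp = IterableSimpleNamespace(  # the hyperparameter fields this function reads (example values)
    mosaic=1.0, copy_paste=0.1, copy_paste_mode="flip", mixup=0.1, cutmix=0.0,
    degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0,
    hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, flipud=0.0, fliplr=0.5,
)
stub_dataset = SimpleNamespace(data={}, use_keypoints=False)  # hypothetical stand-in, not a real dataset

transforms = v8_transforms(stub_dataset, imgsz=640, hyp=hyp)
print(transforms)  # Compose([Compose([Mosaic, CopyPaste, RandomPerspective]), MixUp, CutMix, ...])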
If False, uses LetterBox resizing.\n\nReturns:\n (Compose): A composition of image transformations to be applied to the dataset.\n\nExamples:\n >>> from ultralytics.data.dataset import YOLODataset\n >>> from ultralytics.utils import IterableSimpleNamespace\n >>> dataset = YOLODataset(img_path=\"path/to/images\", imgsz=640)\n >>> hyp = IterableSimpleNamespace(mosaic=1.0, copy_paste=0.5, degrees=10.0, translate=0.2, scale=0.9)\n >>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)\n >>> augmented_data = transforms(dataset[0])", "parameters": [ "dataset", "imgsz: int", "hyp: IterableSimpleNamespace", "stretch: bool" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "function_v8_transforms_e5d75732" }, { "content": "def classify_transforms(\n size: Union[Tuple[int, int], int] = 224,\n mean: Tuple[float, float, float] = DEFAULT_MEAN,\n std: Tuple[float, float, float] = DEFAULT_STD,\n interpolation: str = \"BILINEAR\",\n crop_fraction: float = None,\n):\n \"\"\"\n Create a composition of image transforms for classification tasks.\n\n This function generates a sequence of torchvision transforms suitable for preprocessing images\n for classification models during evaluation or inference. The transforms include resizing,\n center cropping, conversion to tensor, and normalization.\n\n Args:\n size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. 
If a\n tuple, it defines (height, width).\n mean (Tuple[float, float, float]): Mean values for each RGB channel used in normalization.\n std (Tuple[float, float, float]): Standard deviation values for each RGB channel used in normalization.\n interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.\n crop_fraction (float): Deprecated, will be removed in a future version.\n\n Returns:\n (torchvision.transforms.Compose): A composition of torchvision transforms.\n\n Examples:\n >>> transforms = classify_transforms(size=224)\n >>> img = Image.open(\"path/to/image.jpg\")\n >>> transformed_img = transforms(img)\n \"\"\"\n import torchvision.transforms as T # scope for faster 'import ultralytics'\n\n scale_size = size if isinstance(size, (tuple, list)) and len(size) == 2 else (size, size)\n\n if crop_fraction:\n raise DeprecationWarning(\n \"'crop_fraction' arg of classify_transforms is deprecated, will be removed in a future version.\"\n )\n\n # Aspect ratio is preserved, crops center within image, no borders are added, image is lost\n if scale_size[0] == scale_size[1]:\n # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)\n tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))]\n else:\n # Resize the shortest edge to matching target dim for non-square target\n tfl = [T.Resize(scale_size)]\n tfl += [T.CenterCrop(size), T.ToTensor(), T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]\n return T.Compose(tfl)", "chunk_type": "function", "name": "classify_transforms", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2587, "end_line": 2634, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Create a composition of image transforms for classification tasks.\n\nThis function generates a sequence of torchvision transforms suitable for preprocessing images\nfor classification models during evaluation or inference. The transforms include resizing,\ncenter cropping, conversion to tensor, and normalization.\n\nArgs:\n size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. 
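A short, hedged usage sketch for the evaluation-time transform (not from the source); the random input stands in for a real photo.

# Hedged usage sketch for classify_transforms (illustrative, not from the source).
import numpy as np
from PIL import Image
from ultralytics.data.augment import classify_transforms

tf = classify_transforms(size=224)  # Resize(shortest edge) -> CenterCrop -> ToTensor -> Normalize
img = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
x = tf(img)
print(x.shape, x.dtype)  # torch.Size([3, 224, 224]) torch.float32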
If a\n tuple, it defines (height, width).\n mean (Tuple[float, float, float]): Mean values for each RGB channel used in normalization.\n std (Tuple[float, float, float]): Standard deviation values for each RGB channel used in normalization.\n interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.\n crop_fraction (float): Deprecated, will be removed in a future version.\n\nReturns:\n (torchvision.transforms.Compose): A composition of torchvision transforms.\n\nExamples:\n >>> transforms = classify_transforms(size=224)\n >>> img = Image.open(\"path/to/image.jpg\")\n >>> transformed_img = transforms(img)", "parameters": [ "size: Union[Tuple[int, int], int]", "mean: Tuple[float, float, float]", "std: Tuple[float, float, float]", "interpolation: str", "crop_fraction: float" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "function_classify_transforms_37d802eb" }, { "content": "def classify_augmentations(\n size: int = 224,\n mean: Tuple[float, float, float] = DEFAULT_MEAN,\n std: Tuple[float, float, float] = DEFAULT_STD,\n scale: Tuple[float, float] = None,\n ratio: Tuple[float, float] = None,\n hflip: float = 0.5,\n vflip: float = 0.0,\n auto_augment: str = None,\n hsv_h: float = 0.015, # image HSV-Hue augmentation (fraction)\n hsv_s: float = 0.4, # image HSV-Saturation augmentation (fraction)\n hsv_v: float = 0.4, # image HSV-Value augmentation (fraction)\n force_color_jitter: bool = False,\n erasing: float = 0.0,\n interpolation: str = \"BILINEAR\",\n):\n \"\"\"\n Create a composition of image augmentation transforms for classification tasks.\n\n This function generates a set of image transformations suitable for training classification models. It includes\n options for resizing, flipping, color jittering, auto augmentation, and random erasing.\n\n Args:\n size (int): Target size for the image after transformations.\n mean (Tuple[float, float, float]): Mean values for each RGB channel used in normalization.\n std (Tuple[float, float, float]): Standard deviation values for each RGB channel used in normalization.\n scale (Tuple[float, float] | None): Range of size of the origin size cropped.\n ratio (Tuple[float, float] | None): Range of aspect ratio of the origin aspect ratio cropped.\n hflip (float): Probability of horizontal flip.\n vflip (float): Probability of vertical flip.\n auto_augment (str | None): Auto augmentation policy. 
Can be 'randaugment', 'augmix', 'autoaugment' or None.\n hsv_h (float): Image HSV-Hue augmentation factor.\n hsv_s (float): Image HSV-Saturation augmentation factor.\n hsv_v (float): Image HSV-Value augmentation factor.\n force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled.\n erasing (float): Probability of random erasing.\n interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.\n\n Returns:\n (torchvision.transforms.Compose): A composition of image augmentation transforms.\n\n Examples:\n >>> transforms = classify_augmentations(size=224, auto_augment=\"randaugment\")\n >>> augmented_image = transforms(original_image)\n \"\"\"\n # Transforms to apply if Albumentations not installed\n import torchvision.transforms as T # scope for faster 'import ultralytics'\n\n if not isinstance(size, int):\n raise TypeError(f\"classify_augmentations() size {size} must be integer, not (list, tuple)\")\n scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range\n ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range\n interpolation = getattr(T.InterpolationMode, interpolation)\n primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]\n if hflip > 0.0:\n primary_tfl.append(T.RandomHorizontalFlip(p=hflip))\n if vflip > 0.0:\n primary_tfl.append(T.RandomVerticalFlip(p=vflip))\n\n secondary_tfl = []\n disable_color_jitter = False\n if auto_augment:\n assert isinstance(auto_augment, str), f\"Provided argument should be string, but got type {type(auto_augment)}\"\n # color jitter is typically disabled if AA/RA on,\n # this allows override without breaking old hparm cfgs\n disable_color_jitter = not force_color_jitter\n\n if auto_augment == \"randaugment\":\n if TORCHVISION_0_11:\n secondary_tfl.append(T.RandAugment(interpolation=interpolation))\n else:\n LOGGER.warning('\"auto_augment=randaugment\" requires torchvision >= 0.11.0. Disabling it.')\n\n elif auto_augment == \"augmix\":\n if TORCHVISION_0_13:\n secondary_tfl.append(T.AugMix(interpolation=interpolation))\n else:\n LOGGER.warning('\"auto_augment=augmix\" requires torchvision >= 0.13.0. Disabling it.')\n\n elif auto_augment == \"autoaugment\":\n if TORCHVISION_0_10:\n secondary_tfl.append(T.AutoAugment(interpolation=interpolation))\n else:\n LOGGER.warning('\"auto_augment=autoaugment\" requires torchvision >= 0.10.0. Disabling it.')\n\n else:\n raise ValueError(\n f'Invalid auto_augment policy: {auto_augment}. Should be one of \"randaugment\", '\n f'\"augmix\", \"autoaugment\" or None'\n )\n\n if not disable_color_jitter:\n secondary_tfl.append(T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h))\n\n final_tfl = [\n T.ToTensor(),\n T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),\n T.RandomErasing(p=erasing, inplace=True),\n ]\n\n return T.Compose(primary_tfl + secondary_tfl + final_tfl)", "chunk_type": "function", "name": "classify_augmentations", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2638, "end_line": 2738, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": "Create a composition of image augmentation transforms for classification tasks.\n\nThis function generates a set of image transformations suitable for training classification models. 
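Editor's note: a short training-time sketch of `classify_augmentations` from the chunk above, under the assumption that torchvision >= 0.11 is installed (otherwise the function logs a warning and drops the RandAugment policy, as shown in its code). The image path is a placeholder.

```python
from PIL import Image

from ultralytics.data.augment import classify_augmentations

# Training pipeline: random resized crop + flip + RandAugment + ToTensor/Normalize + random erasing.
train_tfms = classify_augmentations(size=224, auto_augment="randaugment", hflip=0.5, erasing=0.4)
img = Image.open("path/to/train_image.jpg").convert("RGB")  # placeholder path
x = train_tfms(img)                                         # augmented, normalized tensor of shape (3, 224, 224)
print(x.shape)
```

Note that when `auto_augment` is set and `force_color_jitter` is left False, the ColorJitter stage is skipped, matching the `disable_color_jitter` logic in the chunk.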
It includes\noptions for resizing, flipping, color jittering, auto augmentation, and random erasing.\n\nArgs:\n size (int): Target size for the image after transformations.\n mean (Tuple[float, float, float]): Mean values for each RGB channel used in normalization.\n std (Tuple[float, float, float]): Standard deviation values for each RGB channel used in normalization.\n scale (Tuple[float, float] | None): Range of size of the origin size cropped.\n ratio (Tuple[float, float] | None): Range of aspect ratio of the origin aspect ratio cropped.\n hflip (float): Probability of horizontal flip.\n vflip (float): Probability of vertical flip.\n auto_augment (str | None): Auto augmentation policy. Can be 'randaugment', 'augmix', 'autoaugment' or None.\n hsv_h (float): Image HSV-Hue augmentation factor.\n hsv_s (float): Image HSV-Saturation augmentation factor.\n hsv_v (float): Image HSV-Value augmentation factor.\n force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled.\n erasing (float): Probability of random erasing.\n interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.\n\nReturns:\n (torchvision.transforms.Compose): A composition of image augmentation transforms.\n\nExamples:\n >>> transforms = classify_augmentations(size=224, auto_augment=\"randaugment\")\n >>> augmented_image = transforms(original_image)", "parameters": [ "size: int", "mean: Tuple[float, float, float]", "std: Tuple[float, float, float]", "scale: Tuple[float, float]", "ratio: Tuple[float, float]", "hflip: float", "vflip: float", "auto_augment: str", "hsv_h: float", "hsv_s: float", "hsv_v: float", "force_color_jitter: bool", "erasing: float", "interpolation: str" ], "return_type": null, "decorators": [], "complexity_score": 12, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "function_classify_augmentations_40ee6c89" }, { "content": "class ClassifyLetterBox:\n \"\"\"\n A class for resizing and padding images for classification tasks.\n\n This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).\n It resizes and pads images to a specified size while maintaining the original aspect ratio.\n\n Attributes:\n h (int): Target height of the image.\n w (int): Target width of the image.\n auto (bool): If True, automatically calculates the short side using stride.\n stride (int): The stride value, used when 'auto' is True.\n\n Methods:\n __call__: Apply the letterbox transformation to an input image.\n\n Examples:\n >>> transform = ClassifyLetterBox(size=(640, 640), auto=False, stride=32)\n >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> result = transform(img)\n >>> print(result.shape)\n (640, 640, 
3)\n \"\"\"\n\n def __init__(self, size: Union[int, Tuple[int, int]] = (640, 640), auto: bool = False, stride: int = 32):\n \"\"\"\n Initialize the ClassifyLetterBox object for image preprocessing.\n\n This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and\n pads images to a specified size while maintaining the original aspect ratio.\n\n Args:\n size (int | Tuple[int, int]): Target size for the letterboxed image. If an int, a square image of\n (size, size) is created. If a tuple, it should be (height, width).\n auto (bool): If True, automatically calculates the short side based on stride.\n stride (int): The stride value, used when 'auto' is True.\n\n Attributes:\n h (int): Target height of the letterboxed image.\n w (int): Target width of the letterboxed image.\n auto (bool): Flag indicating whether to automatically calculate short side.\n stride (int): Stride value for automatic short side calculation.\n\n Examples:\n >>> transform = ClassifyLetterBox(size=224)\n >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> result = transform(img)\n >>> print(result.shape)\n (224, 224, 3)\n \"\"\"\n super().__init__()\n self.h, self.w = (size, size) if isinstance(size, int) else size\n self.auto = auto # pass max size integer, automatically solve for short side using stride\n self.stride = stride # used with auto\n\n def __call__(self, im: np.ndarray) -> np.ndarray:\n \"\"\"\n Resize and pad an image using the letterbox method.\n\n This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio,\n then pads the resized image to match the target size.\n\n Args:\n im (np.ndarray): Input image as a numpy array with shape (H, W, C).\n\n Returns:\n (np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are\n the target height and width respectively.\n\n Examples:\n >>> letterbox = ClassifyLetterBox(size=(640, 640))\n >>> image = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)\n >>> resized_image = letterbox(image)\n >>> print(resized_image.shape)\n (640, 640, 3)\n \"\"\"\n imh, imw = im.shape[:2]\n r = min(self.h / imh, self.w / imw) # ratio of new/old dimensions\n h, w = round(imh * r), round(imw * r) # resized image dimensions\n\n # Calculate padding dimensions\n hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)\n top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)\n\n # Create padded image\n im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)\n im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)\n return im_out", "chunk_type": "class", "name": "ClassifyLetterBox", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2742, "end_line": 2829, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "A class for resizing and padding images for classification tasks.\n\nThis class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).\nIt resizes and pads images to a specified size while maintaining the original aspect ratio.\n\nAttributes:\n h (int): Target height of the image.\n w (int): Target width of the image.\n auto (bool): If True, automatically calculates the short side using stride.\n stride (int): The stride value, used when 'auto' is True.\n\nMethods:\n __call__: Apply the letterbox transformation to an input image.\n\nExamples:\n >>> transform 
= ClassifyLetterBox(size=(640, 640), auto=False, stride=32)\n >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> result = transform(img)\n >>> print(result.shape)\n (640, 640, 3)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_ClassifyLetterBox_650fb3b3" }, { "content": "class CenterCrop:\n \"\"\"\n Apply center cropping to images for classification tasks.\n\n This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect\n ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).\n\n Attributes:\n h (int): Target height of the cropped image.\n w (int): Target width of the cropped image.\n\n Methods:\n __call__: Apply the center crop transformation to an input image.\n\n Examples:\n >>> transform = CenterCrop(640)\n >>> image = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)\n >>> cropped_image = transform(image)\n >>> print(cropped_image.shape)\n (640, 640, 3)\n \"\"\"\n\n def __init__(self, size: Union[int, Tuple[int, int]] = (640, 640)):\n \"\"\"\n Initialize the CenterCrop object for image preprocessing.\n\n This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).\n It performs a center crop on input images to a specified size.\n\n Args:\n size (int | Tuple[int, int]): The desired output size of the crop. If size is an int, a square crop\n (size, size) is made. If size is a sequence like (h, w), it is used as the output size.\n\n Returns:\n (None): This method initializes the object and does not return anything.\n\n Examples:\n >>> transform = CenterCrop(224)\n >>> img = np.random.rand(300, 300, 3)\n >>> cropped_img = transform(img)\n >>> print(cropped_img.shape)\n (224, 224, 3)\n \"\"\"\n super().__init__()\n self.h, self.w = (size, size) if isinstance(size, int) else size\n\n def __call__(self, im: Union[Image.Image, np.ndarray]) -> np.ndarray:\n \"\"\"\n Apply center cropping to an input image.\n\n This method resizes and crops the center of the image using a letterbox method. 
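Editor's note: a small numpy-based sketch exercising the `ClassifyLetterBox` transform described above; the input array stands in for an image read with `cv2.imread`.

```python
import numpy as np

from ultralytics.data.augment import ClassifyLetterBox

letterbox = ClassifyLetterBox(size=(640, 640), auto=False, stride=32)
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # HWC numpy image, e.g. from cv2.imread
out = letterbox(img)                                            # aspect ratio preserved, borders padded with 114
print(out.shape)                                                # (640, 640, 3)
```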
It maintains the aspect\n ratio of the original image while fitting it into the specified dimensions.\n\n Args:\n im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a\n PIL Image object.\n\n Returns:\n (np.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).\n\n Examples:\n >>> transform = CenterCrop(size=224)\n >>> image = np.random.randint(0, 255, (640, 480, 3), dtype=np.uint8)\n >>> cropped_image = transform(image)\n >>> assert cropped_image.shape == (224, 224, 3)\n \"\"\"\n if isinstance(im, Image.Image): # convert from PIL to numpy array if required\n im = np.asarray(im)\n imh, imw = im.shape[:2]\n m = min(imh, imw) # min dimension\n top, left = (imh - m) // 2, (imw - m) // 2\n return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)", "chunk_type": "class", "name": "CenterCrop", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2833, "end_line": 2904, "start_col": 0, "end_col": 111, "parent_name": null, "docstring": "Apply center cropping to images for classification tasks.\n\nThis class performs center cropping on input images, resizing them to a specified size while maintaining the aspect\nratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).\n\nAttributes:\n h (int): Target height of the cropped image.\n w (int): Target width of the cropped image.\n\nMethods:\n __call__: Apply the center crop transformation to an input image.\n\nExamples:\n >>> transform = CenterCrop(640)\n >>> image = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)\n >>> cropped_image = transform(image)\n >>> print(cropped_image.shape)\n (640, 640, 3)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", "ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_CenterCrop_34bb7dfb" }, { "content": "class ToTensor:\n \"\"\"\n Convert an image from a numpy array to a PyTorch tensor.\n\n This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).\n\n Attributes:\n half (bool): If True, converts the image to half precision (float16).\n\n Methods:\n __call__: Apply the tensor conversion to an input image.\n\n Examples:\n >>> transform = ToTensor(half=True)\n >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)\n >>> tensor_img = transform(img)\n >>> print(tensor_img.shape, tensor_img.dtype)\n torch.Size([3, 640, 640]) torch.float16\n\n Notes:\n The input image is expected to be in BGR format with shape (H, W, C).\n The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1].\n \"\"\"\n\n def 
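Editor's note: a companion sketch for the `CenterCrop` transform described above, mirroring the shapes used in its docstring example; the random array is a stand-in for a real image.

```python
import numpy as np

from ultralytics.data.augment import CenterCrop

crop = CenterCrop(640)
img = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)  # HWC numpy image
out = crop(img)                                                   # largest centered square, resized to 640x640
print(out.shape)                                                  # (640, 640, 3)
```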
__init__(self, half: bool = False):\n \"\"\"\n Initialize the ToTensor object for converting images to PyTorch tensors.\n\n This class is designed to be used as part of a transformation pipeline for image preprocessing in the\n Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option\n for half-precision (float16) conversion.\n\n Args:\n half (bool): If True, converts the tensor to half precision (float16).\n\n Examples:\n >>> transform = ToTensor(half=True)\n >>> img = np.random.rand(640, 640, 3)\n >>> tensor_img = transform(img)\n >>> print(tensor_img.dtype)\n torch.float16\n \"\"\"\n super().__init__()\n self.half = half\n\n def __call__(self, im: np.ndarray) -> torch.Tensor:\n \"\"\"\n Transform an image from a numpy array to a PyTorch tensor.\n\n This method converts the input image from a numpy array to a PyTorch tensor, applying optional\n half-precision conversion and normalization. The image is transposed from HWC to CHW format and\n the color channels are reversed from BGR to RGB.\n\n Args:\n im (np.ndarray): Input image as a numpy array with shape (H, W, C) in RGB order.\n\n Returns:\n (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized\n to [0, 1] with shape (C, H, W) in RGB order.\n\n Examples:\n >>> transform = ToTensor(half=True)\n >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)\n >>> tensor_img = transform(img)\n >>> print(tensor_img.shape, tensor_img.dtype)\n torch.Size([3, 640, 640]) torch.float16\n \"\"\"\n im = np.ascontiguousarray(im.transpose((2, 0, 1))) # HWC to CHW -> contiguous\n im = torch.from_numpy(im) # to torch\n im = im.half() if self.half else im.float() # uint8 to fp16/32\n im /= 255.0 # 0-255 to 0.0-1.0\n return im", "chunk_type": "class", "name": "ToTensor", "file_path": "ultralytics\\ultralytics\\data\\augment.py", "start_line": 2908, "end_line": 2979, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Convert an image from a numpy array to a PyTorch tensor.\n\nThis class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).\n\nAttributes:\n half (bool): If True, converts the image to half precision (float16).\n\nMethods:\n __call__: Apply the tensor conversion to an input image.\n\nExamples:\n >>> transform = ToTensor(half=True)\n >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)\n >>> tensor_img = transform(img)\n >>> print(tensor_img.shape, tensor_img.dtype)\n torch.Size([3, 640, 640]) torch.float16\n\nNotes:\n The input image is expected to be in BGR format with shape (H, W, C).\n The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1].", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.deepcopy", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "torch.nn.functional", "ultralytics.data.utils.polygons2masks", "ultralytics.data.utils.polygons2masks_overlap", "ultralytics.utils.LOGGER", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.instance.Instances", "ultralytics.utils.metrics.bbox_ioa", "ultralytics.utils.ops.segment2box", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxyxyxy2xywhr", "ultralytics.utils.torch_utils.TORCHVISION_0_10", "ultralytics.utils.torch_utils.TORCHVISION_0_11", 
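Editor's note: a minimal sketch of the numpy-to-tensor conversion performed by the `ToTensor` class above; the random array stands in for an HWC image and no particular channel order is asserted here.

```python
import numpy as np

from ultralytics.data.augment import ToTensor

to_tensor = ToTensor(half=False)
img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)  # HWC uint8 image
t = to_tensor(img)                                              # CHW float32 tensor scaled to [0, 1]
print(t.shape, t.dtype, float(t.max()) <= 1.0)
```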
"ultralytics.utils.torch_utils.TORCHVISION_0_13", "torchvision.transforms", "torchvision.transforms", "os", "albumentations" ], "chunk_id": "class_ToTensor_55a8ebac" }, { "content": "import glob", "chunk_type": "import", "name": "glob", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_glob_26ce03f5" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_4ab96fff" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_5c6aefd1" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_432a023f" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_8919b0ce" }, { "content": "from multiprocessing.pool import ThreadPool", "chunk_type": "import", "name": "ThreadPool", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ThreadPool_7fb2f694" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_a1d7a643" }, { "content": "from typing import Any, Dict, List, Optional, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple, Union_e0d2d66b" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 10, 
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_325460e9" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_14764043" }, { "content": "from torch.utils.data import Dataset", "chunk_type": "import", "name": "Dataset", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Dataset_84da3680" }, { "content": "from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds", "chunk_type": "import", "name": "FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds_ae8e8ea1" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM", "chunk_type": "import", "name": "DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_9bef5e22" }, { "content": "from ultralytics.utils.patches import imread", "chunk_type": "import", "name": "imread", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_imread_4656660a" }, { "content": "class BaseDataset(Dataset):\n \"\"\"\n Base dataset class for loading and processing image data.\n\n This class provides core functionality for loading images, caching, and preparing data for training and inference\n in object detection tasks.\n\n Attributes:\n img_path (str): Path to the folder containing images.\n imgsz (int): Target image size for resizing.\n augment (bool): Whether to apply data augmentation.\n single_cls (bool): Whether to treat all objects as a single class.\n prefix (str): Prefix to print in log messages.\n fraction (float): Fraction of dataset to utilize.\n channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).\n cv2_flag (int): OpenCV flag for reading images.\n im_files (List[str]): List of image file paths.\n labels (List[Dict]): List of label data dictionaries.\n ni (int): Number of images in the dataset.\n rect (bool): Whether to use rectangular training.\n batch_size (int): Size of batches.\n stride (int): Stride used in the model.\n pad (float): Padding value.\n buffer (list): 
Buffer for mosaic images.\n max_buffer_length (int): Maximum buffer size.\n ims (list): List of loaded images.\n im_hw0 (list): List of original image dimensions (h, w).\n im_hw (list): List of resized image dimensions (h, w).\n npy_files (List[Path]): List of numpy file paths.\n cache (str): Cache images to RAM or disk during training.\n transforms (callable): Image transformation function.\n batch_shapes (np.ndarray): Batch shapes for rectangular training.\n batch (np.ndarray): Batch index of each image.\n\n Methods:\n get_img_files: Read image files from the specified path.\n update_labels: Update labels to include only specified classes.\n load_image: Load an image from the dataset.\n cache_images: Cache images to memory or disk.\n cache_images_to_disk: Save an image as an *.npy file for faster loading.\n check_cache_disk: Check image caching requirements vs available disk space.\n check_cache_ram: Check image caching requirements vs available memory.\n set_rectangle: Set the shape of bounding boxes as rectangles.\n get_image_and_label: Get and return label information from the dataset.\n update_labels_info: Custom label format method to be implemented by subclasses.\n build_transforms: Build transformation pipeline to be implemented by subclasses.\n get_labels: Get labels method to be implemented by subclasses.\n \"\"\"\n\n def __init__(\n self,\n img_path: Union[str, List[str]],\n imgsz: int = 640,\n cache: Union[bool, str] = False,\n augment: bool = True,\n hyp: Dict[str, Any] = DEFAULT_CFG,\n prefix: str = \"\",\n rect: bool = False,\n batch_size: int = 16,\n stride: int = 32,\n pad: float = 0.5,\n single_cls: bool = False,\n classes: Optional[List[int]] = None,\n fraction: float = 1.0,\n channels: int = 3,\n ):\n \"\"\"\n Initialize BaseDataset with given configuration and options.\n\n Args:\n img_path (str | List[str]): Path to the folder containing images or list of image paths.\n imgsz (int): Image size for resizing.\n cache (bool | str): Cache images to RAM or disk during training.\n augment (bool): If True, data augmentation is applied.\n hyp (Dict[str, Any]): Hyperparameters to apply data augmentation.\n prefix (str): Prefix to print in log messages.\n rect (bool): If True, rectangular training is used.\n batch_size (int): Size of batches.\n stride (int): Stride used in the model.\n pad (float): Padding value.\n single_cls (bool): If True, single class training is used.\n classes (List[int], optional): List of included classes.\n fraction (float): Fraction of dataset to utilize.\n channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).\n \"\"\"\n super().__init__()\n self.img_path = img_path\n self.imgsz = imgsz\n self.augment = augment\n self.single_cls = single_cls\n self.prefix = prefix\n self.fraction = fraction\n self.channels = channels\n self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR\n self.im_files = self.get_img_files(self.img_path)\n self.labels = self.get_labels()\n self.update_labels(include_class=classes) # single_cls and include_class\n self.ni = len(self.labels) # number of images\n self.rect = rect\n self.batch_size = batch_size\n self.stride = stride\n self.pad = pad\n if self.rect:\n assert self.batch_size is not None\n self.set_rectangle()\n\n # Buffer thread for mosaic images\n self.buffer = [] # buffer size = batch size\n self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0\n\n # Cache images (options are cache = True, False, None, \"ram\", \"disk\")\n self.ims, 
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni\n self.npy_files = [Path(f).with_suffix(\".npy\") for f in self.im_files]\n self.cache = cache.lower() if isinstance(cache, str) else \"ram\" if cache is True else None\n if self.cache == \"ram\" and self.check_cache_ram():\n if hyp.deterministic:\n LOGGER.warning(\n \"cache='ram' may produce non-deterministic training results. \"\n \"Consider cache='disk' as a deterministic alternative if your disk space allows.\"\n )\n self.cache_images()\n elif self.cache == \"disk\" and self.check_cache_disk():\n self.cache_images()\n\n # Transforms\n self.transforms = self.build_transforms(hyp=hyp)\n\n def get_img_files(self, img_path: Union[str, List[str]]) -> List[str]:\n \"\"\"\n Read image files from the specified path.\n\n Args:\n img_path (str | List[str]): Path or list of paths to image directories or files.\n\n Returns:\n (List[str]): List of image file paths.\n\n Raises:\n FileNotFoundError: If no images are found or the path doesn't exist.\n \"\"\"\n try:\n f = [] # image files\n for p in img_path if isinstance(img_path, list) else [img_path]:\n p = Path(p) # os-agnostic\n if p.is_dir(): # dir\n f += glob.glob(str(p / \"**\" / \"*.*\"), recursive=True)\n # F = list(p.rglob('*.*')) # pathlib\n elif p.is_file(): # file\n with open(p, encoding=\"utf-8\") as t:\n t = t.read().strip().splitlines()\n parent = str(p.parent) + os.sep\n f += [x.replace(\"./\", parent) if x.startswith(\"./\") else x for x in t] # local to global path\n # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)\n else:\n raise FileNotFoundError(f\"{self.prefix}{p} does not exist\")\n im_files = sorted(x.replace(\"/\", os.sep) for x in f if x.rpartition(\".\")[-1].lower() in IMG_FORMATS)\n # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib\n assert im_files, f\"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}\"\n except Exception as e:\n raise FileNotFoundError(f\"{self.prefix}Error loading data from {img_path}\\n{HELP_URL}\") from e\n if self.fraction < 1:\n im_files = im_files[: round(len(im_files) * self.fraction)] # retain a fraction of the dataset\n check_file_speeds(im_files, prefix=self.prefix) # check image read speeds\n return im_files\n\n def update_labels(self, include_class: Optional[List[int]]) -> None:\n \"\"\"\n Update labels to include only specified classes.\n\n Args:\n include_class (List[int], optional): List of classes to include. 
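Editor's note: to make the class-filtering step described in `update_labels` concrete, here is a standalone sketch (with made-up label data) of the same numpy masking the method applies per image; the label dict and values are illustrative only.

```python
import numpy as np

# Hypothetical single-image label dict in the format BaseDataset stores.
label = {
    "cls": np.array([[0], [2], [5], [2]]),  # (n, 1) class ids
    "bboxes": np.random.rand(4, 4),         # (n, 4) boxes
}
include_class = [2, 5]
include_class_array = np.array(include_class).reshape(1, -1)

j = (label["cls"] == include_class_array).any(1)            # boolean mask of rows to keep
label["cls"], label["bboxes"] = label["cls"][j], label["bboxes"][j]
print(label["cls"].ravel())                                  # [2 5 2]
```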
If None, all classes are included.\n \"\"\"\n include_class_array = np.array(include_class).reshape(1, -1)\n for i in range(len(self.labels)):\n if include_class is not None:\n cls = self.labels[i][\"cls\"]\n bboxes = self.labels[i][\"bboxes\"]\n segments = self.labels[i][\"segments\"]\n keypoints = self.labels[i][\"keypoints\"]\n j = (cls == include_class_array).any(1)\n self.labels[i][\"cls\"] = cls[j]\n self.labels[i][\"bboxes\"] = bboxes[j]\n if segments:\n self.labels[i][\"segments\"] = [segments[si] for si, idx in enumerate(j) if idx]\n if keypoints is not None:\n self.labels[i][\"keypoints\"] = keypoints[j]\n if self.single_cls:\n self.labels[i][\"cls\"][:, 0] = 0\n\n def load_image(self, i: int, rect_mode: bool = True) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:\n \"\"\"\n Load an image from dataset index 'i'.\n\n Args:\n i (int): Index of the image to load.\n rect_mode (bool): Whether to use rectangular resizing.\n\n Returns:\n im (np.ndarray): Loaded image as a NumPy array.\n hw_original (Tuple[int, int]): Original image dimensions in (height, width) format.\n hw_resized (Tuple[int, int]): Resized image dimensions in (height, width) format.\n\n Raises:\n FileNotFoundError: If the image file is not found.\n \"\"\"\n im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]\n if im is None: # not cached in RAM\n if fn.exists(): # load npy\n try:\n im = np.load(fn)\n except Exception as e:\n LOGGER.warning(f\"{self.prefix}Removing corrupt *.npy image file {fn} due to: {e}\")\n Path(fn).unlink(missing_ok=True)\n im = imread(f, flags=self.cv2_flag) # BGR\n else: # read image\n im = imread(f, flags=self.cv2_flag) # BGR\n if im is None:\n raise FileNotFoundError(f\"Image Not Found {f}\")\n\n h0, w0 = im.shape[:2] # orig hw\n if rect_mode: # resize long side to imgsz while maintaining aspect ratio\n r = self.imgsz / max(h0, w0) # ratio\n if r != 1: # if sizes are not equal\n w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))\n im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)\n elif not (h0 == w0 == self.imgsz): # resize by stretching image to square imgsz\n im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)\n if im.ndim == 2:\n im = im[..., None]\n\n # Add to buffer if training with augmentations\n if self.augment:\n self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized\n self.buffer.append(i)\n if 1 < len(self.buffer) >= self.max_buffer_length: # prevent empty buffer\n j = self.buffer.pop(0)\n if self.cache != \"ram\":\n self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None\n\n return im, (h0, w0), im.shape[:2]\n\n return self.ims[i], self.im_hw0[i], self.im_hw[i]\n\n def cache_images(self) -> None:\n \"\"\"Cache images to memory or disk for faster training.\"\"\"\n b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes\n fcn, storage = (self.cache_images_to_disk, \"Disk\") if self.cache == \"disk\" else (self.load_image, \"RAM\")\n with ThreadPool(NUM_THREADS) as pool:\n results = pool.imap(fcn, range(self.ni))\n pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)\n for i, x in pbar:\n if self.cache == \"disk\":\n b += self.npy_files[i].stat().st_size\n else: # 'ram'\n self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)\n b += self.ims[i].nbytes\n pbar.desc = f\"{self.prefix}Caching images ({b / gb:.1f}GB {storage})\"\n pbar.close()\n\n def cache_images_to_disk(self, i: int) -> 
None:\n \"\"\"Save an image as an *.npy file for faster loading.\"\"\"\n f = self.npy_files[i]\n if not f.exists():\n np.save(f.as_posix(), imread(self.im_files[i]), allow_pickle=False)\n\n def check_cache_disk(self, safety_margin: float = 0.5) -> bool:\n \"\"\"\n Check if there's enough disk space for caching images.\n\n Args:\n safety_margin (float): Safety margin factor for disk space calculation.\n\n Returns:\n (bool): True if there's enough disk space, False otherwise.\n \"\"\"\n import shutil\n\n b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes\n n = min(self.ni, 30) # extrapolate from 30 random images\n for _ in range(n):\n im_file = random.choice(self.im_files)\n im = imread(im_file)\n if im is None:\n continue\n b += im.nbytes\n if not os.access(Path(im_file).parent, os.W_OK):\n self.cache = None\n LOGGER.warning(f\"{self.prefix}Skipping caching images to disk, directory not writeable\")\n return False\n disk_required = b * self.ni / n * (1 + safety_margin) # bytes required to cache dataset to disk\n total, used, free = shutil.disk_usage(Path(self.im_files[0]).parent)\n if disk_required > free:\n self.cache = None\n LOGGER.warning(\n f\"{self.prefix}{disk_required / gb:.1f}GB disk space required, \"\n f\"with {int(safety_margin * 100)}% safety margin but only \"\n f\"{free / gb:.1f}/{total / gb:.1f}GB free, not caching images to disk\"\n )\n return False\n return True\n\n def check_cache_ram(self, safety_margin: float = 0.5) -> bool:\n \"\"\"\n Check if there's enough RAM for caching images.\n\n Args:\n safety_margin (float): Safety margin factor for RAM calculation.\n\n Returns:\n (bool): True if there's enough RAM, False otherwise.\n \"\"\"\n b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes\n n = min(self.ni, 30) # extrapolate from 30 random images\n for _ in range(n):\n im = imread(random.choice(self.im_files)) # sample image\n if im is None:\n continue\n ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio\n b += im.nbytes * ratio**2\n mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM\n mem = __import__(\"psutil\").virtual_memory()\n if mem_required > mem.available:\n self.cache = None\n LOGGER.warning(\n f\"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images \"\n f\"with {int(safety_margin * 100)}% safety margin but only \"\n f\"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images\"\n )\n return False\n return True\n\n def set_rectangle(self) -> None:\n \"\"\"Set the shape of bounding boxes for YOLO detections as rectangles.\"\"\"\n bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index\n nb = bi[-1] + 1 # number of batches\n\n s = np.array([x.pop(\"shape\") for x in self.labels]) # hw\n ar = s[:, 0] / s[:, 1] # aspect ratio\n irect = ar.argsort()\n self.im_files = [self.im_files[i] for i in irect]\n self.labels = [self.labels[i] for i in irect]\n ar = ar[irect]\n\n # Set training image shapes\n shapes = [[1, 1]] * nb\n for i in range(nb):\n ari = ar[bi == i]\n mini, maxi = ari.min(), ari.max()\n if maxi < 1:\n shapes[i] = [maxi, 1]\n elif mini > 1:\n shapes[i] = [1, 1 / mini]\n\n self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride\n self.batch = bi # batch index of image\n\n def __getitem__(self, index: int) -> Dict[str, Any]:\n \"\"\"Return transformed label information for given index.\"\"\"\n return self.transforms(self.get_image_and_label(index))\n\n def 
get_image_and_label(self, index: int) -> Dict[str, Any]:\n \"\"\"\n Get and return label information from the dataset.\n\n Args:\n index (int): Index of the image to retrieve.\n\n Returns:\n (Dict[str, Any]): Label dictionary with image and metadata.\n \"\"\"\n label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948\n label.pop(\"shape\", None) # shape is for rect, remove it\n label[\"img\"], label[\"ori_shape\"], label[\"resized_shape\"] = self.load_image(index)\n label[\"ratio_pad\"] = (\n label[\"resized_shape\"][0] / label[\"ori_shape\"][0],\n label[\"resized_shape\"][1] / label[\"ori_shape\"][1],\n ) # for evaluation\n if self.rect:\n label[\"rect_shape\"] = self.batch_shapes[self.batch[index]]\n return self.update_labels_info(label)\n\n def __len__(self) -> int:\n \"\"\"Return the length of the labels list for the dataset.\"\"\"\n return len(self.labels)\n\n def update_labels_info(self, label: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Custom your label format here.\"\"\"\n return label\n\n def build_transforms(self, hyp: Optional[Dict[str, Any]] = None):\n \"\"\"\n Users can customize augmentations here.\n\n Examples:\n >>> if self.augment:\n ... # Training transforms\n ... return Compose([])\n >>> else:\n ... # Val transforms\n ... return Compose([])\n \"\"\"\n raise NotImplementedError\n\n def get_labels(self) -> List[Dict[str, Any]]:\n \"\"\"\n Users can customize their own format here.\n\n Examples:\n Ensure output is a dictionary with the following keys:\n >>> dict(\n ... im_file=im_file,\n ... shape=shape, # format: (height, width)\n ... cls=cls,\n ... bboxes=bboxes, # xywh\n ... segments=segments, # xy\n ... keypoints=keypoints, # xy\n ... normalized=True, # or False\n ... bbox_format=\"xyxy\", # or xywh, ltwh\n ... 
)\n \"\"\"\n raise NotImplementedError", "chunk_type": "class", "name": "BaseDataset", "file_path": "ultralytics\\ultralytics\\data\\base.py", "start_line": 21, "end_line": 441, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Base dataset class for loading and processing image data.\n\nThis class provides core functionality for loading images, caching, and preparing data for training and inference\nin object detection tasks.\n\nAttributes:\n img_path (str): Path to the folder containing images.\n imgsz (int): Target image size for resizing.\n augment (bool): Whether to apply data augmentation.\n single_cls (bool): Whether to treat all objects as a single class.\n prefix (str): Prefix to print in log messages.\n fraction (float): Fraction of dataset to utilize.\n channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).\n cv2_flag (int): OpenCV flag for reading images.\n im_files (List[str]): List of image file paths.\n labels (List[Dict]): List of label data dictionaries.\n ni (int): Number of images in the dataset.\n rect (bool): Whether to use rectangular training.\n batch_size (int): Size of batches.\n stride (int): Stride used in the model.\n pad (float): Padding value.\n buffer (list): Buffer for mosaic images.\n max_buffer_length (int): Maximum buffer size.\n ims (list): List of loaded images.\n im_hw0 (list): List of original image dimensions (h, w).\n im_hw (list): List of resized image dimensions (h, w).\n npy_files (List[Path]): List of numpy file paths.\n cache (str): Cache images to RAM or disk during training.\n transforms (callable): Image transformation function.\n batch_shapes (np.ndarray): Batch shapes for rectangular training.\n batch (np.ndarray): Batch index of each image.\n\nMethods:\n get_img_files: Read image files from the specified path.\n update_labels: Update labels to include only specified classes.\n load_image: Load an image from the dataset.\n cache_images: Cache images to memory or disk.\n cache_images_to_disk: Save an image as an *.npy file for faster loading.\n check_cache_disk: Check image caching requirements vs available disk space.\n check_cache_ram: Check image caching requirements vs available memory.\n set_rectangle: Set the shape of bounding boxes as rectangles.\n get_image_and_label: Get and return label information from the dataset.\n update_labels_info: Custom label format method to be implemented by subclasses.\n build_transforms: Build transformation pipeline to be implemented by subclasses.\n get_labels: Get labels method to be implemented by subclasses.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "math", "os", "random", "copy.deepcopy", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch.utils.data.Dataset", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.HELP_URL", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.check_file_speeds", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.patches.imread", "shutil", "Dataset" ], "chunk_id": "class_BaseDataset_fd713299" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, 
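Editor's note: `BaseDataset` leaves `get_labels`, `build_transforms`, and optionally `update_labels_info` to subclasses, as its docstrings above state. The sketch below is an illustrative subclass only, with placeholder labels and an identity transform; instantiating it requires a real image folder, so that line is left commented.

```python
from typing import Any, Dict, List, Optional

from ultralytics.data.base import BaseDataset


class MinimalDataset(BaseDataset):
    """Illustrative subclass filling in the two abstract hooks of BaseDataset."""

    def get_labels(self) -> List[Dict[str, Any]]:
        # Real implementations parse annotation files; here every image gets an empty, placeholder label.
        return [
            {"im_file": f, "shape": (640, 640), "cls": [], "bboxes": [], "normalized": True, "bbox_format": "xywh"}
            for f in self.im_files
        ]

    def build_transforms(self, hyp: Optional[Dict[str, Any]] = None):
        # Identity transform; a real subclass would return a Compose([...]) pipeline here.
        return lambda label: label


# dataset = MinimalDataset(img_path="path/to/images", imgsz=640)  # requires an actual image directory
```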
"docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_301c709c" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_0803aebf" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_b2b72e90" }, { "content": "from typing import Any, Dict, Iterator", "chunk_type": "import", "name": "Any, Dict, Iterator", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Iterator_85dc8a9f" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_dbba1e73" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_86648396" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_4c4d2132" }, { "content": "from torch.utils.data import dataloader, distributed", "chunk_type": "import", "name": "dataloader, distributed", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_dataloader, distributed_208a4858" }, { "content": "from ultralytics.cfg import IterableSimpleNamespace", "chunk_type": "import", "name": "IterableSimpleNamespace", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IterableSimpleNamespace_b6471caa" }, { "content": "from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset", "chunk_type": 
"import", "name": "GroundingDataset, YOLODataset, YOLOMultiModalDataset", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 89, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_GroundingDataset, YOLODataset, YOLOMultiModalDataset_ed9b15f8" }, { "content": "from ultralytics.data.loaders import (\n LOADERS,\n LoadImagesAndVideos,\n LoadPilAndNumpy,\n LoadScreenshots,\n LoadStreams,\n LoadTensor,\n SourceTypes,\n autocast_list,\n)", "chunk_type": "import", "name": "LOADERS, LoadImagesAndVideos, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor, SourceTypes, autocast_list", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 15, "end_line": 24, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOADERS, LoadImagesAndVideos, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor, SourceTypes, autocast_list_9868a3c1" }, { "content": "from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS", "chunk_type": "import", "name": "IMG_FORMATS, PIN_MEMORY, VID_FORMATS", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IMG_FORMATS, PIN_MEMORY, VID_FORMATS_2cc3b9c9" }, { "content": "from ultralytics.utils import RANK, colorstr", "chunk_type": "import", "name": "RANK, colorstr", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 26, "end_line": 26, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RANK, colorstr_d7bb1932" }, { "content": "from ultralytics.utils.checks import check_file", "chunk_type": "import", "name": "check_file", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_file_83291c28" }, { "content": "class InfiniteDataLoader(dataloader.DataLoader):\n \"\"\"\n Dataloader that reuses workers for infinite iteration.\n\n This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency\n for training loops that need to iterate through the dataset multiple times without recreating workers.\n\n Attributes:\n batch_sampler (_RepeatSampler): A sampler that repeats indefinitely.\n iterator (Iterator): The iterator from the parent DataLoader.\n\n Methods:\n __len__: Return the length of the batch sampler's sampler.\n __iter__: Create a sampler that repeats indefinitely.\n __del__: Ensure workers are properly terminated.\n reset: Reset the iterator, useful when modifying dataset settings during training.\n\n Examples:\n Create an infinite dataloader for training\n >>> dataset = YOLODataset(...)\n >>> dataloader = InfiniteDataLoader(dataset, batch_size=16, shuffle=True)\n >>> for batch in dataloader: # Infinite iteration\n >>> 
train_step(batch)\n \"\"\"\n\n def __init__(self, *args: Any, **kwargs: Any):\n \"\"\"Initialize the InfiniteDataLoader with the same arguments as DataLoader.\"\"\"\n super().__init__(*args, **kwargs)\n object.__setattr__(self, \"batch_sampler\", _RepeatSampler(self.batch_sampler))\n self.iterator = super().__iter__()\n\n def __len__(self) -> int:\n \"\"\"Return the length of the batch sampler's sampler.\"\"\"\n return len(self.batch_sampler.sampler)\n\n def __iter__(self) -> Iterator:\n \"\"\"Create an iterator that yields indefinitely from the underlying iterator.\"\"\"\n for _ in range(len(self)):\n yield next(self.iterator)\n\n def __del__(self):\n \"\"\"Ensure that workers are properly terminated when the dataloader is deleted.\"\"\"\n try:\n if not hasattr(self.iterator, \"_workers\"):\n return\n for w in self.iterator._workers: # force terminate\n if w.is_alive():\n w.terminate()\n self.iterator._shutdown_workers() # cleanup\n except Exception:\n pass\n\n def reset(self):\n \"\"\"Reset the iterator to allow modifications to the dataset during training.\"\"\"\n self.iterator = self._get_iterator()", "chunk_type": "class", "name": "InfiniteDataLoader", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 30, "end_line": 84, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Dataloader that reuses workers for infinite iteration.\n\nThis dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency\nfor training loops that need to iterate through the dataset multiple times without recreating workers.\n\nAttributes:\n batch_sampler (_RepeatSampler): A sampler that repeats indefinitely.\n iterator (Iterator): The iterator from the parent DataLoader.\n\nMethods:\n __len__: Return the length of the batch sampler's sampler.\n __iter__: Create a sampler that repeats indefinitely.\n __del__: Ensure workers are properly terminated.\n reset: Reset the iterator, useful when modifying dataset settings during training.\n\nExamples:\n Create an infinite dataloader for training\n >>> dataset = YOLODataset(...)\n >>> dataloader = InfiniteDataLoader(dataset, batch_size=16, shuffle=True)\n >>> for batch in dataloader: # Infinite iteration\n >>> train_step(batch)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file", "dataloader.DataLoader" ], "chunk_id": "class_InfiniteDataLoader_2cfa5867" }, { "content": "class _RepeatSampler:\n \"\"\"\n Sampler that repeats forever for infinite iteration.\n\n This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration\n 
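Editor's note: a toy sketch of the `InfiniteDataLoader` class above using a `TensorDataset` stand-in; the point is only that the underlying iterator persists across passes and can be rebuilt with `reset()`.

```python
import torch
from torch.utils.data import TensorDataset

from ultralytics.data.build import InfiniteDataLoader

dataset = TensorDataset(torch.arange(8).float())  # toy stand-in for a YOLODataset
loader = InfiniteDataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for epoch in range(2):        # with num_workers > 0 the same worker processes are reused across passes
    for (batch,) in loader:   # 2 batches per pass (8 samples / batch size 4)
        print(epoch, batch.tolist())

loader.reset()                # rebuild the iterator, e.g. after changing dataset settings mid-training
```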
over a dataset without recreating the sampler.\n\n Attributes:\n sampler (Dataset.sampler): The sampler to repeat.\n \"\"\"\n\n def __init__(self, sampler: Any):\n \"\"\"Initialize the _RepeatSampler with a sampler to repeat indefinitely.\"\"\"\n self.sampler = sampler\n\n def __iter__(self) -> Iterator:\n \"\"\"Iterate over the sampler indefinitely, yielding its contents.\"\"\"\n while True:\n yield from iter(self.sampler)", "chunk_type": "class", "name": "_RepeatSampler", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 87, "end_line": 105, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Sampler that repeats forever for infinite iteration.\n\nThis sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration\nover a dataset without recreating the sampler.\n\nAttributes:\n sampler (Dataset.sampler): The sampler to repeat.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "class__RepeatSampler_54c6a24a" }, { "content": "def seed_worker(worker_id: int): # noqa\n \"\"\"Set dataloader worker seed for reproducibility across worker processes.\"\"\"\n worker_seed = torch.initial_seed() % 2**32\n np.random.seed(worker_seed)\n random.seed(worker_seed)", "chunk_type": "function", "name": "seed_worker", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 108, "end_line": 112, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Set dataloader worker seed for reproducibility across worker processes.", "parameters": [ "worker_id: int" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "function_seed_worker_a075f28e" }, { 
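Editor's note: one plausible wiring of `seed_worker` into a multi-worker loader for reproducible shuffling; the fixed seed value is arbitrary, and the rank-offset comment is an assumption about distributed use rather than something shown in this chunk.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from ultralytics.data.build import seed_worker

generator = torch.Generator()
generator.manual_seed(0)  # any fixed base seed; a per-process rank offset is typically added for DDP (assumption)

loader = DataLoader(
    TensorDataset(torch.arange(16).float()),
    batch_size=4,
    shuffle=True,
    num_workers=2,              # guard with `if __name__ == "__main__":` on spawn-based platforms
    worker_init_fn=seed_worker,  # re-seeds numpy/random in each worker from torch's per-worker seed
    generator=generator,
)
```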
"content": "def build_yolo_dataset(\n cfg: IterableSimpleNamespace,\n img_path: str,\n batch: int,\n data: Dict[str, Any],\n mode: str = \"train\",\n rect: bool = False,\n stride: int = 32,\n multi_modal: bool = False,\n):\n \"\"\"Build and return a YOLO dataset based on configuration parameters.\"\"\"\n dataset = YOLOMultiModalDataset if multi_modal else YOLODataset\n return dataset(\n img_path=img_path,\n imgsz=cfg.imgsz,\n batch_size=batch,\n augment=mode == \"train\", # augmentation\n hyp=cfg, # TODO: probably add a get_hyps_from_cfg function\n rect=cfg.rect or rect, # rectangular batches\n cache=cfg.cache or None,\n single_cls=cfg.single_cls or False,\n stride=int(stride),\n pad=0.0 if mode == \"train\" else 0.5,\n prefix=colorstr(f\"{mode}: \"),\n task=cfg.task,\n classes=cfg.classes,\n data=data,\n fraction=cfg.fraction if mode == \"train\" else 1.0,\n )", "chunk_type": "function", "name": "build_yolo_dataset", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 115, "end_line": 143, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a YOLO dataset based on configuration parameters.", "parameters": [ "cfg: IterableSimpleNamespace", "img_path: str", "batch: int", "data: Dict[str, Any]", "mode: str", "rect: bool", "stride: int", "multi_modal: bool" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "function_build_yolo_dataset_bd7fd70d" }, { "content": "def build_grounding(\n cfg: IterableSimpleNamespace,\n img_path: str,\n json_file: str,\n batch: int,\n mode: str = \"train\",\n rect: bool = False,\n stride: int = 32,\n max_samples: int = 80,\n):\n \"\"\"Build and return a GroundingDataset based on configuration parameters.\"\"\"\n return GroundingDataset(\n img_path=img_path,\n json_file=json_file,\n max_samples=max_samples,\n imgsz=cfg.imgsz,\n batch_size=batch,\n augment=mode == \"train\", # augmentation\n hyp=cfg, # TODO: probably add a get_hyps_from_cfg function\n rect=cfg.rect or rect, # rectangular batches\n cache=cfg.cache or None,\n single_cls=cfg.single_cls or False,\n stride=int(stride),\n pad=0.0 if mode == \"train\" else 0.5,\n prefix=colorstr(f\"{mode}: \"),\n task=cfg.task,\n classes=cfg.classes,\n fraction=cfg.fraction if mode == \"train\" else 1.0,\n )", "chunk_type": "function", "name": "build_grounding", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 146, "end_line": 174, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a GroundingDataset based on configuration parameters.", "parameters": [ "cfg: IterableSimpleNamespace", "img_path: str", 
"json_file: str", "batch: int", "mode: str", "rect: bool", "stride: int", "max_samples: int" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "function_build_grounding_4d37b377" }, { "content": "def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):\n \"\"\"\n Create and return an InfiniteDataLoader or DataLoader for training or validation.\n\n Args:\n dataset (Dataset): Dataset to load data from.\n batch (int): Batch size for the dataloader.\n workers (int): Number of worker threads for loading data.\n shuffle (bool, optional): Whether to shuffle the dataset.\n rank (int, optional): Process rank in distributed training. -1 for single-GPU training.\n drop_last (bool, optional): Whether to drop the last incomplete batch.\n\n Returns:\n (InfiniteDataLoader): A dataloader that can be used for training or validation.\n\n Examples:\n Create a dataloader for training\n >>> dataset = YOLODataset(...)\n >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True)\n \"\"\"\n batch = min(batch, len(dataset))\n nd = torch.cuda.device_count() # number of CUDA devices\n nw = min(os.cpu_count() // max(nd, 1), workers) # number of workers\n sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)\n generator = torch.Generator()\n generator.manual_seed(6148914691236517205 + RANK)\n return InfiniteDataLoader(\n dataset=dataset,\n batch_size=batch,\n shuffle=shuffle and sampler is None,\n num_workers=nw,\n sampler=sampler,\n pin_memory=PIN_MEMORY,\n collate_fn=getattr(dataset, \"collate_fn\", None),\n worker_init_fn=seed_worker,\n generator=generator,\n drop_last=drop_last,\n )", "chunk_type": "function", "name": "build_dataloader", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 177, "end_line": 214, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Create and return an InfiniteDataLoader or DataLoader for training or validation.\n\nArgs:\n dataset (Dataset): Dataset to load data from.\n batch (int): Batch size for the dataloader.\n workers (int): Number of worker threads for loading data.\n shuffle (bool, optional): Whether to shuffle the dataset.\n rank (int, optional): Process rank in distributed training. 
-1 for single-GPU training.\n drop_last (bool, optional): Whether to drop the last incomplete batch.\n\nReturns:\n (InfiniteDataLoader): A dataloader that can be used for training or validation.\n\nExamples:\n Create a dataloader for training\n >>> dataset = YOLODataset(...)\n >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True)", "parameters": [ "dataset", "batch: int", "workers: int", "shuffle: bool", "rank: int", "drop_last: bool" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "function_build_dataloader_c4705f8e" }, { "content": "def check_source(source):\n \"\"\"\n Check the type of input source and return corresponding flag values.\n\n Args:\n source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.\n\n Returns:\n source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The processed source.\n webcam (bool): Whether the source is a webcam.\n screenshot (bool): Whether the source is a screenshot.\n from_img (bool): Whether the source is an image or list of images.\n in_memory (bool): Whether the source is an in-memory object.\n tensor (bool): Whether the source is a torch.Tensor.\n\n Examples:\n Check a file path source\n >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source(\"image.jpg\")\n\n Check a webcam source\n >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source(0)\n \"\"\"\n webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False\n if isinstance(source, (str, int, Path)): # int for local usb camera\n source = str(source)\n source_lower = source.lower()\n is_file = source_lower.rpartition(\".\")[-1] in (IMG_FORMATS | VID_FORMATS)\n is_url = source_lower.startswith((\"https://\", \"http://\", \"rtsp://\", \"rtmp://\", \"tcp://\"))\n webcam = source.isnumeric() or source.endswith(\".streams\") or (is_url and not is_file)\n screenshot = source_lower == \"screen\"\n if is_url and is_file:\n source = check_file(source) # download\n elif isinstance(source, LOADERS):\n in_memory = True\n elif isinstance(source, (list, tuple)):\n source = autocast_list(source) # convert all list elements to PIL or np arrays\n from_img = True\n elif isinstance(source, (Image.Image, np.ndarray)):\n from_img = True\n elif isinstance(source, torch.Tensor):\n tensor = True\n else:\n raise TypeError(\"Unsupported image type. 
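Minimal sketch of the worker-sizing heuristic in build_dataloader: the requested worker count is capped by CPU cores divided across visible CUDA devices (assumes torch; pick_num_workers is an illustrative name, and os.cpu_count() is defaulted to 1 if unavailable).

import os

import torch


def pick_num_workers(requested: int) -> int:
    nd = torch.cuda.device_count()      # number of CUDA devices (0 on CPU-only machines)
    cpus = os.cpu_count() or 1          # fall back to 1 if the CPU count cannot be determined
    return min(cpus // max(nd, 1), requested)


print(pick_num_workers(8))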
For supported types see https://docs.ultralytics.com/modes/predict\")\n\n return source, webcam, screenshot, from_img, in_memory, tensor", "chunk_type": "function", "name": "check_source", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 217, "end_line": 261, "start_col": 0, "end_col": 66, "parent_name": null, "docstring": "Check the type of input source and return corresponding flag values.\n\nArgs:\n source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.\n\nReturns:\n source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The processed source.\n webcam (bool): Whether the source is a webcam.\n screenshot (bool): Whether the source is a screenshot.\n from_img (bool): Whether the source is an image or list of images.\n in_memory (bool): Whether the source is an in-memory object.\n tensor (bool): Whether the source is a torch.Tensor.\n\nExamples:\n Check a file path source\n >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source(\"image.jpg\")\n\n Check a webcam source\n >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source(0)", "parameters": [ "source" ], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "function_check_source_aa8b3868" }, { "content": "def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):\n \"\"\"\n Load an inference source for object detection and apply necessary transformations.\n\n Args:\n source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.\n batch (int, optional): Batch size for dataloaders.\n vid_stride (int, optional): The frame interval for video sources.\n buffer (bool, optional): Whether stream frames will be buffered.\n channels (int, optional): The number of input channels for the model.\n\n Returns:\n (Dataset): A dataset object for the specified input source with attached source_type attribute.\n\n Examples:\n Load an image source for inference\n >>> dataset = load_inference_source(\"image.jpg\", batch=1)\n\n Load a video stream source\n >>> dataset = load_inference_source(\"rtsp://example.com/stream\", vid_stride=2)\n \"\"\"\n source, stream, screenshot, from_img, in_memory, tensor = check_source(source)\n source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)\n\n # Dataloader\n if tensor:\n dataset = LoadTensor(source)\n elif in_memory:\n dataset = source\n elif stream:\n dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer, 
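Self-contained sketch of the string checks in check_source: the suffix decides "file", the scheme prefix decides "URL", and a numeric string or *.streams selects webcam/stream mode. The format sets below are trimmed stand-ins; the real IMG_FORMATS/VID_FORMATS live in ultralytics.data.utils.

IMG_FORMATS = {"jpg", "jpeg", "png"}   # trimmed stand-ins for the full ultralytics sets
VID_FORMATS = {"mp4", "avi", "mkv"}


def classify(source: str) -> dict:
    s = source.lower()
    is_file = s.rpartition(".")[-1] in (IMG_FORMATS | VID_FORMATS)
    is_url = s.startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
    return {
        "is_file": is_file,
        "is_url": is_url,
        "webcam_or_stream": source.isnumeric() or source.endswith(".streams") or (is_url and not is_file),
        "screenshot": s == "screen",
    }


for src in ("image.jpg", "https://example.com/video.mp4", "rtsp://example.com/stream", "0", "screen"):
    print(src, classify(src))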
channels=channels)\n elif screenshot:\n dataset = LoadScreenshots(source, channels=channels)\n elif from_img:\n dataset = LoadPilAndNumpy(source, channels=channels)\n else:\n dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride, channels=channels)\n\n # Attach source types to the dataset\n setattr(dataset, \"source_type\", source_type)\n\n return dataset", "chunk_type": "function", "name": "load_inference_source", "file_path": "ultralytics\\ultralytics\\data\\build.py", "start_line": 264, "end_line": 305, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Load an inference source for object detection and apply necessary transformations.\n\nArgs:\n source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.\n batch (int, optional): Batch size for dataloaders.\n vid_stride (int, optional): The frame interval for video sources.\n buffer (bool, optional): Whether stream frames will be buffered.\n channels (int, optional): The number of input channels for the model.\n\nReturns:\n (Dataset): A dataset object for the specified input source with attached source_type attribute.\n\nExamples:\n Load an image source for inference\n >>> dataset = load_inference_source(\"image.jpg\", batch=1)\n\n Load a video stream source\n >>> dataset = load_inference_source(\"rtsp://example.com/stream\", vid_stride=2)", "parameters": [ "source", "batch: int", "vid_stride: int", "buffer: bool", "channels: int" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "os", "random", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Iterator", "numpy", "torch", "PIL.Image", "torch.utils.data.dataloader", "torch.utils.data.distributed", "ultralytics.cfg.IterableSimpleNamespace", "ultralytics.data.dataset.GroundingDataset", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.dataset.YOLOMultiModalDataset", "ultralytics.data.loaders.LOADERS", "ultralytics.data.loaders.LoadImagesAndVideos", "ultralytics.data.loaders.LoadPilAndNumpy", "ultralytics.data.loaders.LoadScreenshots", "ultralytics.data.loaders.LoadStreams", "ultralytics.data.loaders.LoadTensor", "ultralytics.data.loaders.SourceTypes", "ultralytics.data.loaders.autocast_list", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.PIN_MEMORY", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_file" ], "chunk_id": "function_load_inference_source_926d6403" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_69779ef9" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_465ed0a8" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": 
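Usage sketch for load_inference_source, taken from its own docstring example (assumes ultralytics is installed and an image.jpg actually exists next to the script):

from ultralytics.data.build import load_inference_source

dataset = load_inference_source("image.jpg", batch=1)   # dispatches to LoadImagesAndVideos here
print(type(dataset).__name__, dataset.source_type)      # source_type is attached by the function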
null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_34461bf2" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_f5e1e0db" }, { "content": "from concurrent.futures import ThreadPoolExecutor, as_completed", "chunk_type": "import", "name": "ThreadPoolExecutor, as_completed", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ThreadPoolExecutor, as_completed_3b389084" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_6bae5ad7" }, { "content": "from typing import List, Optional, Union", "chunk_type": "import", "name": "List, Optional, Union", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Union_1d29f8c9" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_7937cc34" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_855f3823" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_f0cf54e9" }, { "content": "from ultralytics.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM", "chunk_type": "import", "name": "DATASETS_DIR, LOGGER, NUM_THREADS, TQDM", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_0d99cbcd" }, { "content": "from 
ultralytics.utils.downloads import download, zip_directory", "chunk_type": "import", "name": "download, zip_directory", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_download, zip_directory_1df7d27c" }, { "content": "from ultralytics.utils.files import increment_path", "chunk_type": "import", "name": "increment_path", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_increment_path_f59febf1" }, { "content": "def coco91_to_coco80_class() -> List[int]:\n \"\"\"\n Convert 91-index COCO class IDs to 80-index COCO class IDs.\n\n Returns:\n (List[int]): A list of 91 class IDs where the index represents the 80-index class ID and the value\n is the corresponding 91-index class ID.\n \"\"\"\n return [\n 0,\n 1,\n 2,\n 3,\n 4,\n 5,\n 6,\n 7,\n 8,\n 9,\n 10,\n None,\n 11,\n 12,\n 13,\n 14,\n 15,\n 16,\n 17,\n 18,\n 19,\n 20,\n 21,\n 22,\n 23,\n None,\n 24,\n 25,\n None,\n None,\n 26,\n 27,\n 28,\n 29,\n 30,\n 31,\n 32,\n 33,\n 34,\n 35,\n 36,\n 37,\n 38,\n 39,\n None,\n 40,\n 41,\n 42,\n 43,\n 44,\n 45,\n 46,\n 47,\n 48,\n 49,\n 50,\n 51,\n 52,\n 53,\n 54,\n 55,\n 56,\n 57,\n 58,\n 59,\n None,\n 60,\n None,\n None,\n 61,\n None,\n 62,\n 63,\n 64,\n 65,\n 66,\n 67,\n 68,\n 69,\n 70,\n 71,\n 72,\n None,\n 73,\n 74,\n 75,\n 76,\n 77,\n 78,\n 79,\n None,\n ]", "chunk_type": "function", "name": "coco91_to_coco80_class", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 20, "end_line": 120, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Convert 91-index COCO class IDs to 80-index COCO class IDs.\n\nReturns:\n (List[int]): A list of 91 class IDs where the index represents the 80-index class ID and the value\n is the corresponding 91-index class ID.", "parameters": [], "return_type": "List[int]", "decorators": [], "complexity_score": 1, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_coco91_to_coco80_class_3c68c2e2" }, { "content": "def coco80_to_coco91_class() -> List[int]:\n r\"\"\"\n Convert 80-index (val2014) to 91-index (paper).\n\n Returns:\n (List[int]): A list of 80 class IDs where each value is the corresponding 91-index class ID.\n\n References:\n https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/\n\n Examples:\n >>> import numpy as np\n >>> a = np.loadtxt(\"data/coco.names\", dtype=\"str\", delimiter=\"\\n\")\n >>> b = np.loadtxt(\"data/coco_paper.names\", dtype=\"str\", delimiter=\"\\n\")\n\n Convert the darknet to COCO 
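Quick illustration of the mapping returned by coco91_to_coco80_class: index it with (91-based category_id - 1) to get the 80-class index, with None for the eleven categories absent from the 80-class set (assumes ultralytics is installed).

from ultralytics.data.converter import coco91_to_coco80_class

coco80 = coco91_to_coco80_class()
print(coco80[1 - 1])    # category_id 1  -> 80-class index 0
print(coco80[12 - 1])   # category_id 12 -> None (no counterpart in the 80-class set)
print(sum(v is not None for v in coco80))  # 80 usable entries out of 91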
format\n >>> x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]\n\n Convert the COCO to darknet format\n >>> x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]\n \"\"\"\n return [\n 1,\n 2,\n 3,\n 4,\n 5,\n 6,\n 7,\n 8,\n 9,\n 10,\n 11,\n 13,\n 14,\n 15,\n 16,\n 17,\n 18,\n 19,\n 20,\n 21,\n 22,\n 23,\n 24,\n 25,\n 27,\n 28,\n 31,\n 32,\n 33,\n 34,\n 35,\n 36,\n 37,\n 38,\n 39,\n 40,\n 41,\n 42,\n 43,\n 44,\n 46,\n 47,\n 48,\n 49,\n 50,\n 51,\n 52,\n 53,\n 54,\n 55,\n 56,\n 57,\n 58,\n 59,\n 60,\n 61,\n 62,\n 63,\n 64,\n 65,\n 67,\n 70,\n 72,\n 73,\n 74,\n 75,\n 76,\n 77,\n 78,\n 79,\n 80,\n 81,\n 82,\n 84,\n 85,\n 86,\n 87,\n 88,\n 89,\n 90,\n ]", "chunk_type": "function", "name": "coco80_to_coco91_class", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 123, "end_line": 225, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Convert 80-index (val2014) to 91-index (paper).\n\nReturns:\n (List[int]): A list of 80 class IDs where each value is the corresponding 91-index class ID.\n\nReferences:\n https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/\n\nExamples:\n >>> import numpy as np\n >>> a = np.loadtxt(\"data/coco.names\", dtype=\"str\", delimiter=\"\\n\")\n >>> b = np.loadtxt(\"data/coco_paper.names\", dtype=\"str\", delimiter=\"\\n\")\n\n Convert the darknet to COCO format\n >>> x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]\n\n Convert the COCO to darknet format\n >>> x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]", "parameters": [], "return_type": "List[int]", "decorators": [], "complexity_score": 1, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_coco80_to_coco91_class_3340830f" }, { "content": "def convert_coco(\n labels_dir: str = \"../coco/annotations/\",\n save_dir: str = \"coco_converted/\",\n use_segments: bool = False,\n use_keypoints: bool = False,\n cls91to80: bool = True,\n lvis: bool = False,\n):\n \"\"\"\n Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.\n\n Args:\n labels_dir (str, optional): Path to directory containing COCO dataset annotation files.\n save_dir (str, optional): Path to directory to save results to.\n use_segments (bool, optional): Whether to include segmentation masks in the output.\n use_keypoints (bool, optional): Whether to include keypoint annotations in the output.\n cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.\n lvis (bool, optional): Whether to convert data in lvis dataset way.\n\n Examples:\n >>> from ultralytics.data.converter import convert_coco\n\n Convert COCO annotations to YOLO format\n >>> convert_coco(\"coco/annotations/\", use_segments=True, use_keypoints=False, cls91to80=False)\n\n Convert LVIS annotations to YOLO format\n >>> convert_coco(\"lvis/annotations/\", 
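coco80_to_coco91_class goes the other way: index by the 80-class ID to get the paper's 1-based 91-category ID. A small consistency check composing the two tables (assumes ultralytics is installed):

from ultralytics.data.converter import coco80_to_coco91_class, coco91_to_coco80_class

to91 = coco80_to_coco91_class()   # 80 entries, values are 1-based 91-category IDs
to80 = coco91_to_coco80_class()   # 91 entries, None where no 80-class counterpart exists
print(all(to80[to91[i] - 1] == i for i in range(80)))  # True: the two tables round-trip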
use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)\n \"\"\"\n # Create dataset directory\n save_dir = increment_path(save_dir) # increment if save directory already exists\n for p in save_dir / \"labels\", save_dir / \"images\":\n p.mkdir(parents=True, exist_ok=True) # make dir\n\n # Convert classes\n coco80 = coco91_to_coco80_class()\n\n # Import json\n for json_file in sorted(Path(labels_dir).resolve().glob(\"*.json\")):\n lname = \"\" if lvis else json_file.stem.replace(\"instances_\", \"\")\n fn = Path(save_dir) / \"labels\" / lname # folder name\n fn.mkdir(parents=True, exist_ok=True)\n if lvis:\n # NOTE: create folders for both train and val in advance,\n # since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.\n (fn / \"train2017\").mkdir(parents=True, exist_ok=True)\n (fn / \"val2017\").mkdir(parents=True, exist_ok=True)\n with open(json_file, encoding=\"utf-8\") as f:\n data = json.load(f)\n\n # Create image dict\n images = {f\"{x['id']:d}\": x for x in data[\"images\"]}\n # Create image-annotations dict\n annotations = defaultdict(list)\n for ann in data[\"annotations\"]:\n annotations[ann[\"image_id\"]].append(ann)\n\n image_txt = []\n # Write labels file\n for img_id, anns in TQDM(annotations.items(), desc=f\"Annotations {json_file}\"):\n img = images[f\"{img_id:d}\"]\n h, w = img[\"height\"], img[\"width\"]\n f = str(Path(img[\"coco_url\"]).relative_to(\"http://images.cocodataset.org\")) if lvis else img[\"file_name\"]\n if lvis:\n image_txt.append(str(Path(\"./images\") / f))\n\n bboxes = []\n segments = []\n keypoints = []\n for ann in anns:\n if ann.get(\"iscrowd\", False):\n continue\n # The COCO box format is [top left x, top left y, width, height]\n box = np.array(ann[\"bbox\"], dtype=np.float64)\n box[:2] += box[2:] / 2 # xy top-left corner to center\n box[[0, 2]] /= w # normalize x\n box[[1, 3]] /= h # normalize y\n if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0\n continue\n\n cls = coco80[ann[\"category_id\"] - 1] if cls91to80 else ann[\"category_id\"] - 1 # class\n box = [cls] + box.tolist()\n if box not in bboxes:\n bboxes.append(box)\n if use_segments and ann.get(\"segmentation\") is not None:\n if len(ann[\"segmentation\"]) == 0:\n segments.append([])\n continue\n elif len(ann[\"segmentation\"]) > 1:\n s = merge_multi_segment(ann[\"segmentation\"])\n s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()\n else:\n s = [j for i in ann[\"segmentation\"] for j in i] # all segments concatenated\n s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()\n s = [cls] + s\n segments.append(s)\n if use_keypoints and ann.get(\"keypoints\") is not None:\n keypoints.append(\n box + (np.array(ann[\"keypoints\"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()\n )\n\n # Write\n with open((fn / f).with_suffix(\".txt\"), \"a\", encoding=\"utf-8\") as file:\n for i in range(len(bboxes)):\n if use_keypoints:\n line = (*(keypoints[i]),) # cls, box, keypoints\n else:\n line = (\n *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),\n ) # cls, box or segments\n file.write((\"%g \" * len(line)).rstrip() % line + \"\\n\")\n\n if lvis:\n filename = Path(save_dir) / json_file.name.replace(\"lvis_v1_\", \"\").replace(\".json\", \".txt\")\n with open(filename, \"a\", encoding=\"utf-8\") as f:\n f.writelines(f\"{line}\\n\" for line in image_txt)\n\n LOGGER.info(f\"{'LVIS' if lvis else 'COCO'} data converted successfully.\\nResults saved to {save_dir.resolve()}\")", 
"chunk_type": "function", "name": "convert_coco", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 228, "end_line": 344, "start_col": 0, "end_col": 116, "parent_name": null, "docstring": "Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.\n\nArgs:\n labels_dir (str, optional): Path to directory containing COCO dataset annotation files.\n save_dir (str, optional): Path to directory to save results to.\n use_segments (bool, optional): Whether to include segmentation masks in the output.\n use_keypoints (bool, optional): Whether to include keypoint annotations in the output.\n cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.\n lvis (bool, optional): Whether to convert data in lvis dataset way.\n\nExamples:\n >>> from ultralytics.data.converter import convert_coco\n\n Convert COCO annotations to YOLO format\n >>> convert_coco(\"coco/annotations/\", use_segments=True, use_keypoints=False, cls91to80=False)\n\n Convert LVIS annotations to YOLO format\n >>> convert_coco(\"lvis/annotations/\", use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)", "parameters": [ "labels_dir: str", "save_dir: str", "use_segments: bool", "use_keypoints: bool", "cls91to80: bool", "lvis: bool" ], "return_type": null, "decorators": [], "complexity_score": 21, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_convert_coco_0f3cf17b" }, { "content": "def convert_segment_masks_to_yolo_seg(masks_dir: str, output_dir: str, classes: int):\n \"\"\"\n Convert a dataset of segmentation mask images to the YOLO segmentation format.\n\n This function takes the directory containing the binary format mask images and converts them into YOLO segmentation\n format. The converted masks are saved in the specified output directory.\n\n Args:\n masks_dir (str): The path to the directory where all mask images (png, jpg) are stored.\n output_dir (str): The path to the directory where the converted YOLO segmentation masks will be stored.\n classes (int): Total classes in the dataset i.e. 
for COCO classes=80\n\n Examples:\n >>> from ultralytics.data.converter import convert_segment_masks_to_yolo_seg\n\n The classes here is the total classes in the dataset, for COCO dataset we have 80 classes\n >>> convert_segment_masks_to_yolo_seg(\"path/to/masks_directory\", \"path/to/output/directory\", classes=80)\n\n Notes:\n The expected directory structure for the masks is:\n\n - masks\n ├─ mask_image_01.png or mask_image_01.jpg\n ├─ mask_image_02.png or mask_image_02.jpg\n ├─ mask_image_03.png or mask_image_03.jpg\n └─ mask_image_04.png or mask_image_04.jpg\n\n After execution, the labels will be organized in the following structure:\n\n - output_dir\n ├─ mask_yolo_01.txt\n ├─ mask_yolo_02.txt\n ├─ mask_yolo_03.txt\n └─ mask_yolo_04.txt\n \"\"\"\n pixel_to_class_mapping = {i + 1: i for i in range(classes)}\n for mask_path in Path(masks_dir).iterdir():\n if mask_path.suffix in {\".png\", \".jpg\"}:\n mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) # Read the mask image in grayscale\n img_height, img_width = mask.shape # Get image dimensions\n LOGGER.info(f\"Processing {mask_path} imgsz = {img_height} x {img_width}\")\n\n unique_values = np.unique(mask) # Get unique pixel values representing different classes\n yolo_format_data = []\n\n for value in unique_values:\n if value == 0:\n continue # Skip background\n class_index = pixel_to_class_mapping.get(value, -1)\n if class_index == -1:\n LOGGER.warning(f\"Unknown class for pixel value {value} in file {mask_path}, skipping.\")\n continue\n\n # Create a binary mask for the current class and find contours\n contours, _ = cv2.findContours(\n (mask == value).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n ) # Find contours\n\n for contour in contours:\n if len(contour) >= 3: # YOLO requires at least 3 points for a valid segmentation\n contour = contour.squeeze() # Remove single-dimensional entries\n yolo_format = [class_index]\n for point in contour:\n # Normalize the coordinates\n yolo_format.append(round(point[0] / img_width, 6)) # Rounding to 6 decimal places\n yolo_format.append(round(point[1] / img_height, 6))\n yolo_format_data.append(yolo_format)\n # Save Ultralytics YOLO format data to file\n output_path = Path(output_dir) / f\"{mask_path.stem}.txt\"\n with open(output_path, \"w\", encoding=\"utf-8\") as file:\n for item in yolo_format_data:\n line = \" \".join(map(str, item))\n file.write(line + \"\\n\")\n LOGGER.info(f\"Processed and stored at {output_path} imgsz = {img_height} x {img_width}\")", "chunk_type": "function", "name": "convert_segment_masks_to_yolo_seg", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 347, "end_line": 420, "start_col": 0, "end_col": 100, "parent_name": null, "docstring": "Convert a dataset of segmentation mask images to the YOLO segmentation format.\n\nThis function takes the directory containing the binary format mask images and converts them into YOLO segmentation\nformat. The converted masks are saved in the specified output directory.\n\nArgs:\n masks_dir (str): The path to the directory where all mask images (png, jpg) are stored.\n output_dir (str): The path to the directory where the converted YOLO segmentation masks will be stored.\n classes (int): Total classes in the dataset i.e. 
for COCO classes=80\n\nExamples:\n >>> from ultralytics.data.converter import convert_segment_masks_to_yolo_seg\n\n The classes here is the total classes in the dataset, for COCO dataset we have 80 classes\n >>> convert_segment_masks_to_yolo_seg(\"path/to/masks_directory\", \"path/to/output/directory\", classes=80)\n\nNotes:\n The expected directory structure for the masks is:\n\n - masks\n ├─ mask_image_01.png or mask_image_01.jpg\n ├─ mask_image_02.png or mask_image_02.jpg\n ├─ mask_image_03.png or mask_image_03.jpg\n └─ mask_image_04.png or mask_image_04.jpg\n\n After execution, the labels will be organized in the following structure:\n\n - output_dir\n ├─ mask_yolo_01.txt\n ├─ mask_yolo_02.txt\n ├─ mask_yolo_03.txt\n └─ mask_yolo_04.txt", "parameters": [ "masks_dir: str", "output_dir: str", "classes: int" ], "return_type": null, "decorators": [], "complexity_score": 11, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_convert_segment_masks_to_yolo_seg_2529320f" }, { "content": "def convert_dota_to_yolo_obb(dota_root_path: str):\n \"\"\"\n Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.\n\n The function processes images in the 'train' and 'val' folders of the DOTA dataset. 
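Compact sketch of the mask-to-polygon step in convert_segment_masks_to_yolo_seg: pixel value i+1 encodes class i, OpenCV extracts each class's outer contours, and the points are normalized by the image size. Runs on a synthetic 100x100 mask; assumes numpy and OpenCV 4.x (whose findContours returns two values).

import cv2
import numpy as np

mask = np.zeros((100, 100), dtype=np.uint8)
mask[20:60, 30:70] = 1                           # pixel value 1 -> class index 0

rows = []
for value in np.unique(mask):
    if value == 0:                               # skip background
        continue
    contours, _ = cv2.findContours((mask == value).astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        if len(contour) >= 3:                    # a polygon needs at least 3 points
            pts = contour.squeeze()
            row = [int(value) - 1] + [round(float(v) / 100, 6) for p in pts for v in p]
            rows.append(" ".join(map(str, row)))

print(rows[0])  # "0 x1 y1 x2 y2 ..." normalized to the 100x100 mask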
For each image, it reads the\n associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.\n\n Args:\n dota_root_path (str): The root directory path of the DOTA dataset.\n\n Examples:\n >>> from ultralytics.data.converter import convert_dota_to_yolo_obb\n >>> convert_dota_to_yolo_obb(\"path/to/DOTA\")\n\n Notes:\n The directory structure assumed for the DOTA dataset:\n\n - DOTA\n ├─ images\n │ ├─ train\n │ └─ val\n └─ labels\n ├─ train_original\n └─ val_original\n\n After execution, the function will organize the labels into:\n\n - DOTA\n └─ labels\n ├─ train\n └─ val\n \"\"\"\n dota_root_path = Path(dota_root_path)\n\n # Class names to indices mapping\n class_mapping = {\n \"plane\": 0,\n \"ship\": 1,\n \"storage-tank\": 2,\n \"baseball-diamond\": 3,\n \"tennis-court\": 4,\n \"basketball-court\": 5,\n \"ground-track-field\": 6,\n \"harbor\": 7,\n \"bridge\": 8,\n \"large-vehicle\": 9,\n \"small-vehicle\": 10,\n \"helicopter\": 11,\n \"roundabout\": 12,\n \"soccer-ball-field\": 13,\n \"swimming-pool\": 14,\n \"container-crane\": 15,\n \"airport\": 16,\n \"helipad\": 17,\n }\n\n def convert_label(image_name: str, image_width: int, image_height: int, orig_label_dir: Path, save_dir: Path):\n \"\"\"Convert a single image's DOTA annotation to YOLO OBB format and save it to a specified directory.\"\"\"\n orig_label_path = orig_label_dir / f\"{image_name}.txt\"\n save_path = save_dir / f\"{image_name}.txt\"\n\n with orig_label_path.open(\"r\") as f, save_path.open(\"w\") as g:\n lines = f.readlines()\n for line in lines:\n parts = line.strip().split()\n if len(parts) < 9:\n continue\n class_name = parts[8]\n class_idx = class_mapping[class_name]\n coords = [float(p) for p in parts[:8]]\n normalized_coords = [\n coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)\n ]\n formatted_coords = [f\"{coord:.6g}\" for coord in normalized_coords]\n g.write(f\"{class_idx} {' '.join(formatted_coords)}\\n\")\n\n for phase in {\"train\", \"val\"}:\n image_dir = dota_root_path / \"images\" / phase\n orig_label_dir = dota_root_path / \"labels\" / f\"{phase}_original\"\n save_dir = dota_root_path / \"labels\" / phase\n\n save_dir.mkdir(parents=True, exist_ok=True)\n\n image_paths = list(image_dir.iterdir())\n for image_path in TQDM(image_paths, desc=f\"Processing {phase} images\"):\n if image_path.suffix != \".png\":\n continue\n image_name_without_ext = image_path.stem\n img = cv2.imread(str(image_path))\n h, w = img.shape[:2]\n convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir)", "chunk_type": "function", "name": "convert_dota_to_yolo_obb", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 423, "end_line": 513, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.\n\nThe function processes images in the 'train' and 'val' folders of the DOTA dataset. 
For each image, it reads the\nassociated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.\n\nArgs:\n dota_root_path (str): The root directory path of the DOTA dataset.\n\nExamples:\n >>> from ultralytics.data.converter import convert_dota_to_yolo_obb\n >>> convert_dota_to_yolo_obb(\"path/to/DOTA\")\n\nNotes:\n The directory structure assumed for the DOTA dataset:\n\n - DOTA\n ├─ images\n │ ├─ train\n │ └─ val\n └─ labels\n ├─ train_original\n └─ val_original\n\n After execution, the function will organize the labels into:\n\n - DOTA\n └─ labels\n ├─ train\n └─ val", "parameters": [ "dota_root_path: str" ], "return_type": null, "decorators": [], "complexity_score": 9, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_convert_dota_to_yolo_obb_b2efd03c" }, { "content": "def min_index(arr1: np.ndarray, arr2: np.ndarray):\n \"\"\"\n Find a pair of indexes with the shortest distance between two arrays of 2D points.\n\n Args:\n arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.\n arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.\n\n Returns:\n idx1 (int): Index of the point in arr1 with the shortest distance.\n idx2 (int): Index of the point in arr2 with the shortest distance.\n \"\"\"\n dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)\n return np.unravel_index(np.argmin(dis, axis=None), dis.shape)", "chunk_type": "function", "name": "min_index", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 516, "end_line": 529, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": "Find a pair of indexes with the shortest distance between two arrays of 2D points.\n\nArgs:\n arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.\n arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.\n\nReturns:\n idx1 (int): Index of the point in arr1 with the shortest distance.\n idx2 (int): Index of the point in arr2 with the shortest distance.", "parameters": [ "arr1: np.ndarray", "arr2: np.ndarray" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_min_index_ad55947b" }, { "content": "def merge_multi_segment(segments: List[List]):\n \"\"\"\n Merge multiple 
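min_index computes all pairwise squared distances between two point sets and returns the index pair of the closest points; merge_multi_segment uses it to decide where to stitch polygon parts together. A quick check (assumes ultralytics and numpy):

import numpy as np

from ultralytics.data.converter import min_index

arr1 = np.array([[0.0, 0.0], [10.0, 0.0]])
arr2 = np.array([[9.0, 1.0], [50.0, 50.0]])
print(min_index(arr1, arr2))  # closest pair: arr1[1] and arr2[0], i.e. indexes (1, 0)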
segments into one list by connecting the coordinates with the minimum distance between each segment.\n\n This function connects these coordinates with a thin line to merge all segments into one.\n\n Args:\n segments (List[List]): Original segmentations in COCO's JSON file.\n Each element is a list of coordinates, like [segmentation1, segmentation2,...].\n\n Returns:\n s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.\n \"\"\"\n s = []\n segments = [np.array(i).reshape(-1, 2) for i in segments]\n idx_list = [[] for _ in range(len(segments))]\n\n # Record the indexes with min distance between each segment\n for i in range(1, len(segments)):\n idx1, idx2 = min_index(segments[i - 1], segments[i])\n idx_list[i - 1].append(idx1)\n idx_list[i].append(idx2)\n\n # Use two round to connect all the segments\n for k in range(2):\n # Forward connection\n if k == 0:\n for i, idx in enumerate(idx_list):\n # Middle segments have two indexes, reverse the index of middle segments\n if len(idx) == 2 and idx[0] > idx[1]:\n idx = idx[::-1]\n segments[i] = segments[i][::-1, :]\n\n segments[i] = np.roll(segments[i], -idx[0], axis=0)\n segments[i] = np.concatenate([segments[i], segments[i][:1]])\n # Deal with the first segment and the last one\n if i in {0, len(idx_list) - 1}:\n s.append(segments[i])\n else:\n idx = [0, idx[1] - idx[0]]\n s.append(segments[i][idx[0] : idx[1] + 1])\n\n else:\n for i in range(len(idx_list) - 1, -1, -1):\n if i not in {0, len(idx_list) - 1}:\n idx = idx_list[i]\n nidx = abs(idx[1] - idx[0])\n s.append(segments[i][nidx:])\n return s", "chunk_type": "function", "name": "merge_multi_segment", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 532, "end_line": 580, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.\n\nThis function connects these coordinates with a thin line to merge all segments into one.\n\nArgs:\n segments (List[List]): Original segmentations in COCO's JSON file.\n Each element is a list of coordinates, like [segmentation1, segmentation2,...].\n\nReturns:\n s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.", "parameters": [ "segments: List[List]" ], "return_type": null, "decorators": [], "complexity_score": 11, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_merge_multi_segment_65f51352" }, { "content": "def yolo_bbox2segment(\n im_dir: Union[str, Path], save_dir: Optional[Union[str, Path]] = None, sam_model: str = \"sam_b.pt\", device=None\n):\n \"\"\"\n Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) in\n YOLO format. 
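Minimal usage sketch for merge_multi_segment with two small squares in COCO's flat [x1, y1, x2, y2, ...] polygon format, concatenated the same way convert_coco does (assumes ultralytics and numpy):

import numpy as np

from ultralytics.data.converter import merge_multi_segment

square_a = [0, 0, 10, 0, 10, 10, 0, 10]     # flat COCO-style polygon part
square_b = [20, 0, 30, 0, 30, 10, 20, 10]
merged = np.concatenate(merge_multi_segment([square_a, square_b]), axis=0)
print(merged.shape)  # one connected (N, 2) point sequence covering both parts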
Generate segmentation data using SAM auto-annotator as needed.\n\n Args:\n im_dir (str | Path): Path to image directory to convert.\n save_dir (str | Path, optional): Path to save the generated labels, labels will be saved\n into `labels-segment` in the same directory level of `im_dir` if save_dir is None.\n sam_model (str): Segmentation model to use for intermediate segmentation data.\n device (int | str, optional): The specific device to run SAM models.\n\n Notes:\n The input directory structure assumed for dataset:\n\n - im_dir\n ├─ 001.jpg\n ├─ ...\n └─ NNN.jpg\n - labels\n ├─ 001.txt\n ├─ ...\n └─ NNN.txt\n \"\"\"\n from ultralytics import SAM\n from ultralytics.data import YOLODataset\n from ultralytics.utils.ops import xywh2xyxy\n\n # NOTE: add placeholder to pass class index check\n dataset = YOLODataset(im_dir, data=dict(names=list(range(1000))))\n if len(dataset.labels[0][\"segments\"]) > 0: # if it's segment data\n LOGGER.info(\"Segmentation labels detected, no need to generate new ones!\")\n return\n\n LOGGER.info(\"Detection labels detected, generating segment labels by SAM model!\")\n sam_model = SAM(sam_model)\n for label in TQDM(dataset.labels, total=len(dataset.labels), desc=\"Generating segment labels\"):\n h, w = label[\"shape\"]\n boxes = label[\"bboxes\"]\n if len(boxes) == 0: # skip empty labels\n continue\n boxes[:, [0, 2]] *= w\n boxes[:, [1, 3]] *= h\n im = cv2.imread(label[\"im_file\"])\n sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False, device=device)\n label[\"segments\"] = sam_results[0].masks.xyn\n\n save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / \"labels-segment\"\n save_dir.mkdir(parents=True, exist_ok=True)\n for label in dataset.labels:\n texts = []\n lb_name = Path(label[\"im_file\"]).with_suffix(\".txt\").name\n txt_file = save_dir / lb_name\n cls = label[\"cls\"]\n for i, s in enumerate(label[\"segments\"]):\n if len(s) == 0:\n continue\n line = (int(cls[i]), *s.reshape(-1))\n texts.append((\"%g \" * len(line)).rstrip() % line)\n with open(txt_file, \"a\", encoding=\"utf-8\") as f:\n f.writelines(text + \"\\n\" for text in texts)\n LOGGER.info(f\"Generated segment labels saved in {save_dir}\")", "chunk_type": "function", "name": "yolo_bbox2segment", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 583, "end_line": 646, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": "Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) in\nYOLO format. 
Generate segmentation data using SAM auto-annotator as needed.\n\nArgs:\n im_dir (str | Path): Path to image directory to convert.\n save_dir (str | Path, optional): Path to save the generated labels, labels will be saved\n into `labels-segment` in the same directory level of `im_dir` if save_dir is None.\n sam_model (str): Segmentation model to use for intermediate segmentation data.\n device (int | str, optional): The specific device to run SAM models.\n\nNotes:\n The input directory structure assumed for dataset:\n\n - im_dir\n ├─ 001.jpg\n ├─ ...\n └─ NNN.jpg\n - labels\n ├─ 001.txt\n ├─ ...\n └─ NNN.txt", "parameters": [ "im_dir: Union[str, Path]", "save_dir: Optional[Union[str, Path]]", "sam_model: str", "device" ], "return_type": null, "decorators": [], "complexity_score": 8, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_yolo_bbox2segment_4c90071d" }, { "content": "def create_synthetic_coco_dataset():\n \"\"\"\n Create a synthetic COCO dataset with random images based on filenames from label lists.\n\n This function downloads COCO labels, reads image filenames from label list files,\n creates synthetic images for train2017 and val2017 subsets, and organizes\n them in the COCO dataset structure. 
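Hedged call sketch for yolo_bbox2segment; the path is a placeholder, a sibling labels/ directory with YOLO detection .txt files is expected per the Notes above, and running it downloads/loads the SAM weights.

from ultralytics.data.converter import yolo_bbox2segment

# Writes generated polygon labels into a labels-segment/ folder next to im_dir
yolo_bbox2segment("path/to/dataset/images", sam_model="sam_b.pt", device="cpu")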
It uses multithreading to generate images efficiently.\n\n Examples:\n >>> from ultralytics.data.converter import create_synthetic_coco_dataset\n >>> create_synthetic_coco_dataset()\n\n Notes:\n - Requires internet connection to download label files.\n - Generates random RGB images of varying sizes (480x480 to 640x640 pixels).\n - Existing test2017 directory is removed as it's not needed.\n - Reads image filenames from train2017.txt and val2017.txt files.\n \"\"\"\n\n def create_synthetic_image(image_file: Path):\n \"\"\"Generate synthetic images with random sizes and colors for dataset augmentation or testing purposes.\"\"\"\n if not image_file.exists():\n size = (random.randint(480, 640), random.randint(480, 640))\n Image.new(\n \"RGB\",\n size=size,\n color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),\n ).save(image_file)\n\n # Download labels\n dir = DATASETS_DIR / \"coco\"\n url = \"https://github.com/ultralytics/assets/releases/download/v0.0.0/\"\n label_zip = \"coco2017labels-segments.zip\"\n download([url + label_zip], dir=dir.parent)\n\n # Create synthetic images\n shutil.rmtree(dir / \"labels\" / \"test2017\", ignore_errors=True) # Remove test2017 directory as not needed\n with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:\n for subset in {\"train2017\", \"val2017\"}:\n subset_dir = dir / \"images\" / subset\n subset_dir.mkdir(parents=True, exist_ok=True)\n\n # Read image filenames from label list file\n label_list_file = dir / f\"{subset}.txt\"\n if label_list_file.exists():\n with open(label_list_file, encoding=\"utf-8\") as f:\n image_files = [dir / line.strip() for line in f]\n\n # Submit all tasks\n futures = [executor.submit(create_synthetic_image, image_file) for image_file in image_files]\n for _ in TQDM(as_completed(futures), total=len(futures), desc=f\"Generating images for {subset}\"):\n pass # The actual work is done in the background\n else:\n LOGGER.warning(f\"Labels file {label_list_file} does not exist. Skipping image creation for {subset}.\")\n\n LOGGER.info(\"Synthetic COCO dataset created successfully.\")", "chunk_type": "function", "name": "create_synthetic_coco_dataset", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 649, "end_line": 704, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": "Create a synthetic COCO dataset with random images based on filenames from label lists.\n\nThis function downloads COCO labels, reads image filenames from label list files,\ncreates synthetic images for train2017 and val2017 subsets, and organizes\nthem in the COCO dataset structure. 
It uses multithreading to generate images efficiently.\n\nExamples:\n >>> from ultralytics.data.converter import create_synthetic_coco_dataset\n >>> create_synthetic_coco_dataset()\n\nNotes:\n - Requires internet connection to download label files.\n - Generates random RGB images of varying sizes (480x480 to 640x640 pixels).\n - Existing test2017 directory is removed as it's not needed.\n - Reads image filenames from train2017.txt and val2017.txt files.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_create_synthetic_coco_dataset_ef0c4047" }, { "content": "def convert_to_multispectral(path: Union[str, Path], n_channels: int = 10, replace: bool = False, zip: bool = False):\n \"\"\"\n Convert RGB images to multispectral images by interpolating across wavelength bands.\n\n This function takes RGB images and interpolates them to create multispectral images with a specified number\n of channels. It can process either a single image or a directory of images.\n\n Args:\n path (str | Path): Path to an image file or directory containing images to convert.\n n_channels (int): Number of spectral channels to generate in the output image.\n replace (bool): Whether to replace the original image file with the converted one.\n zip (bool): Whether to zip the converted images into a zip file.\n\n Examples:\n Convert a single image\n >>> convert_to_multispectral(\"path/to/image.jpg\", n_channels=10)\n\n Convert a dataset\n >>> convert_to_multispectral(\"coco8\", n_channels=10)\n \"\"\"\n from scipy.interpolate import interp1d\n\n from ultralytics.data.utils import IMG_FORMATS\n\n path = Path(path)\n if path.is_dir():\n # Process directory\n im_files = sum([list(path.rglob(f\"*.{ext}\")) for ext in (IMG_FORMATS - {\"tif\", \"tiff\"})], [])\n for im_path in im_files:\n try:\n convert_to_multispectral(im_path, n_channels)\n if replace:\n im_path.unlink()\n except Exception as e:\n LOGGER.info(f\"Error converting {im_path}: {e}\")\n\n if zip:\n zip_directory(path)\n else:\n # Process a single image\n output_path = path.with_suffix(\".tiff\")\n img = cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2RGB)\n\n # Interpolate all pixels at once\n rgb_wavelengths = np.array([650, 510, 475]) # R, G, B wavelengths (nm)\n target_wavelengths = np.linspace(450, 700, n_channels)\n f = interp1d(rgb_wavelengths.T, img, kind=\"linear\", bounds_error=False, fill_value=\"extrapolate\")\n multispectral = f(target_wavelengths)\n cv2.imwritemulti(str(output_path), np.clip(multispectral, 0, 255).astype(np.uint8).transpose(2, 0, 1))\n LOGGER.info(f\"Converted {output_path}\")", "chunk_type": "function", "name": "convert_to_multispectral", "file_path": "ultralytics\\ultralytics\\data\\converter.py", "start_line": 707, "end_line": 756, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Convert RGB images to multispectral images by 
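Standalone sketch of the inner helper in create_synthetic_coco_dataset: write a random-colored RGB image of random size whenever the target file is missing, fanned out over a ThreadPoolExecutor. It writes a few files into a temporary directory; assumes Pillow is installed.

import random
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from PIL import Image


def create_synthetic_image(image_file: Path):
    """Create the file only if it does not already exist, mirroring the helper above."""
    if not image_file.exists():
        size = (random.randint(480, 640), random.randint(480, 640))
        color = tuple(random.randint(0, 255) for _ in range(3))
        Image.new("RGB", size=size, color=color).save(image_file)


tmp = Path(tempfile.mkdtemp())
files = [tmp / f"{i:012d}.jpg" for i in range(4)]   # COCO-style zero-padded names
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(create_synthetic_image, files))   # force all tasks to complete
print(sorted(f.name for f in tmp.iterdir()))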
interpolating across wavelength bands.\n\nThis function takes RGB images and interpolates them to create multispectral images with a specified number\nof channels. It can process either a single image or a directory of images.\n\nArgs:\n path (str | Path): Path to an image file or directory containing images to convert.\n n_channels (int): Number of spectral channels to generate in the output image.\n replace (bool): Whether to replace the original image file with the converted one.\n zip (bool): Whether to zip the converted images into a zip file.\n\nExamples:\n Convert a single image\n >>> convert_to_multispectral(\"path/to/image.jpg\", n_channels=10)\n\n Convert a dataset\n >>> convert_to_multispectral(\"coco8\", n_channels=10)", "parameters": [ "path: Union[str, Path]", "n_channels: int", "replace: bool", "zip: bool" ], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "json", "random", "shutil", "collections.defaultdict", "concurrent.futures.ThreadPoolExecutor", "concurrent.futures.as_completed", "pathlib.Path", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "PIL.Image", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.zip_directory", "ultralytics.utils.files.increment_path", "ultralytics.SAM", "ultralytics.data.YOLODataset", "ultralytics.utils.ops.xywh2xyxy", "scipy.interpolate.interp1d", "ultralytics.data.utils.IMG_FORMATS" ], "chunk_id": "function_convert_to_multispectral_5f50acde" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_4ae3858d" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_907f4be2" }, { "content": "from itertools import repeat", "chunk_type": "import", "name": "repeat", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_repeat_aa38708c" }, { "content": "from multiprocessing.pool import ThreadPool", "chunk_type": "import", "name": "ThreadPool", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ThreadPool_2b3af042" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
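For `convert_to_multispectral`, both documented call patterns are shown below in a small hedged sketch; the paths are placeholders, and the output is a multi-page `.tiff` written next to each source image:

```python
from ultralytics.data.converter import convert_to_multispectral

# Single image: the R/G/B bands (650/510/475 nm) are interpolated to 10 evenly
# spaced wavelengths in 450-700 nm and written to path/to/image.tiff.
convert_to_multispectral("path/to/image.jpg", n_channels=10)

# Directory: every non-TIFF image is converted; replace=True would delete the
# originals and zip=True would archive the directory afterwards.
convert_to_multispectral("coco8", n_channels=10, replace=False, zip=False)
```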
"dependencies": null, "chunk_id": "import_Path_846374c3" }, { "content": "from typing import Any, Dict, List, Optional, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple_0f07f6cf" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_cdb8f6f7" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_e02c452a" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_ff592f13" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_7a02eb02" }, { "content": "from torch.utils.data import ConcatDataset", "chunk_type": "import", "name": "ConcatDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ConcatDataset_7fd36388" }, { "content": "from ultralytics.utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr", "chunk_type": "import", "name": "LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr_c0920eb5" }, { "content": "from ultralytics.utils.instance import Instances", "chunk_type": "import", "name": "Instances", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Instances_b82b840b" }, { "content": "from ultralytics.utils.ops import resample_segments, segments2boxes", "chunk_type": 
"import", "name": "resample_segments, segments2boxes", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_resample_segments, segments2boxes_749ae6c6" }, { "content": "from ultralytics.utils.torch_utils import TORCHVISION_0_18", "chunk_type": "import", "name": "TORCHVISION_0_18", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCHVISION_0_18_53354e03" }, { "content": "from .augment import (\n Compose,\n Format,\n LetterBox,\n RandomLoadText,\n classify_augmentations,\n classify_transforms,\n v8_transforms,\n)", "chunk_type": "import", "name": "Compose, Format, LetterBox, RandomLoadText, classify_augmentations, classify_transforms, v8_transforms", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 21, "end_line": 29, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Compose, Format, LetterBox, RandomLoadText, classify_augmentations, classify_transforms, v8_transforms_0feefe1b" }, { "content": "from .base import BaseDataset", "chunk_type": "import", "name": "BaseDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 30, "end_line": 30, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseDataset_c915cfb7" }, { "content": "from .converter import merge_multi_segment", "chunk_type": "import", "name": "merge_multi_segment", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 31, "end_line": 31, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_merge_multi_segment_ee68b1a4" }, { "content": "from .utils import (\n HELP_URL,\n check_file_speeds,\n get_hash,\n img2label_paths,\n load_dataset_cache_file,\n save_dataset_cache_file,\n verify_image,\n verify_image_label,\n)", "chunk_type": "import", "name": "HELP_URL, check_file_speeds, get_hash, img2label_paths, load_dataset_cache_file, save_dataset_cache_file, verify_image, verify_image_label", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 32, "end_line": 41, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HELP_URL, check_file_speeds, get_hash, img2label_paths, load_dataset_cache_file, save_dataset_cache_file, verify_image, verify_image_label_a4edcc49" }, { "content": "DATASET_CACHE_VERSION = \"1.0.3\"", "chunk_type": "variable", "name": "DATASET_CACHE_VERSION", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 44, "end_line": 44, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
"complexity_score": null, "dependencies": null, "chunk_id": "variable_DATASET_CACHE_VERSION_4e709a66" }, { "content": "class YOLODataset(BaseDataset):\n \"\"\"\n Dataset class for loading object detection and/or segmentation labels in YOLO format.\n\n This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box\n (OBB) tasks using the YOLO format.\n\n Attributes:\n use_segments (bool): Indicates if segmentation masks should be used.\n use_keypoints (bool): Indicates if keypoints should be used for pose estimation.\n use_obb (bool): Indicates if oriented bounding boxes should be used.\n data (dict): Dataset configuration dictionary.\n\n Methods:\n cache_labels: Cache dataset labels, check images and read shapes.\n get_labels: Return dictionary of labels for YOLO training.\n build_transforms: Build and append transforms to the list.\n close_mosaic: Set mosaic, copy_paste and mixup options to 0.0 and build transformations.\n update_labels_info: Update label format for different tasks.\n collate_fn: Collate data samples into batches.\n\n Examples:\n >>> dataset = YOLODataset(img_path=\"path/to/images\", data={\"names\": {0: \"person\"}}, task=\"detect\")\n >>> dataset.get_labels()\n \"\"\"\n\n def __init__(self, *args, data: Optional[Dict] = None, task: str = \"detect\", **kwargs):\n \"\"\"\n Initialize the YOLODataset.\n\n Args:\n data (dict, optional): Dataset configuration dictionary.\n task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.\n *args (Any): Additional positional arguments for the parent class.\n **kwargs (Any): Additional keyword arguments for the parent class.\n \"\"\"\n self.use_segments = task == \"segment\"\n self.use_keypoints = task == \"pose\"\n self.use_obb = task == \"obb\"\n self.data = data\n assert not (self.use_segments and self.use_keypoints), \"Can not use both segments and keypoints.\"\n super().__init__(*args, channels=self.data[\"channels\"], **kwargs)\n\n def cache_labels(self, path: Path = Path(\"./labels.cache\")) -> Dict:\n \"\"\"\n Cache dataset labels, check images and read shapes.\n\n Args:\n path (Path): Path where to save the cache file.\n\n Returns:\n (dict): Dictionary containing cached labels and related information.\n \"\"\"\n x = {\"labels\": []}\n nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages\n desc = f\"{self.prefix}Scanning {path.parent / path.stem}...\"\n total = len(self.im_files)\n nkpt, ndim = self.data.get(\"kpt_shape\", (0, 0))\n if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):\n raise ValueError(\n \"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of \"\n \"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 
'kpt_shape: [17, 3]'\"\n )\n with ThreadPool(NUM_THREADS) as pool:\n results = pool.imap(\n func=verify_image_label,\n iterable=zip(\n self.im_files,\n self.label_files,\n repeat(self.prefix),\n repeat(self.use_keypoints),\n repeat(len(self.data[\"names\"])),\n repeat(nkpt),\n repeat(ndim),\n repeat(self.single_cls),\n ),\n )\n pbar = TQDM(results, desc=desc, total=total)\n for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:\n nm += nm_f\n nf += nf_f\n ne += ne_f\n nc += nc_f\n if im_file:\n x[\"labels\"].append(\n {\n \"im_file\": im_file,\n \"shape\": shape,\n \"cls\": lb[:, 0:1], # n, 1\n \"bboxes\": lb[:, 1:], # n, 4\n \"segments\": segments,\n \"keypoints\": keypoint,\n \"normalized\": True,\n \"bbox_format\": \"xywh\",\n }\n )\n if msg:\n msgs.append(msg)\n pbar.desc = f\"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt\"\n pbar.close()\n\n if msgs:\n LOGGER.info(\"\\n\".join(msgs))\n if nf == 0:\n LOGGER.warning(f\"{self.prefix}No labels found in {path}. {HELP_URL}\")\n x[\"hash\"] = get_hash(self.label_files + self.im_files)\n x[\"results\"] = nf, nm, ne, nc, len(self.im_files)\n x[\"msgs\"] = msgs # warnings\n save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)\n return x\n\n def get_labels(self) -> List[Dict]:\n \"\"\"\n Return dictionary of labels for YOLO training.\n\n This method loads labels from disk or cache, verifies their integrity, and prepares them for training.\n\n Returns:\n (List[dict]): List of label dictionaries, each containing information about an image and its annotations.\n \"\"\"\n self.label_files = img2label_paths(self.im_files)\n cache_path = Path(self.label_files[0]).parent.with_suffix(\".cache\")\n try:\n cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file\n assert cache[\"version\"] == DATASET_CACHE_VERSION # matches current version\n assert cache[\"hash\"] == get_hash(self.label_files + self.im_files) # identical hash\n except (FileNotFoundError, AssertionError, AttributeError):\n cache, exists = self.cache_labels(cache_path), False # run cache ops\n\n # Display cache\n nf, nm, ne, nc, n = cache.pop(\"results\") # found, missing, empty, corrupt, total\n if exists and LOCAL_RANK in {-1, 0}:\n d = f\"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt\"\n TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results\n if cache[\"msgs\"]:\n LOGGER.info(\"\\n\".join(cache[\"msgs\"])) # display warnings\n\n # Read cache\n [cache.pop(k) for k in (\"hash\", \"version\", \"msgs\")] # remove items\n labels = cache[\"labels\"]\n if not labels:\n raise RuntimeError(\n f\"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}\"\n )\n self.im_files = [lb[\"im_file\"] for lb in labels] # update im_files\n\n # Check if the dataset is all boxes or all segments\n lengths = ((len(lb[\"cls\"]), len(lb[\"bboxes\"]), len(lb[\"segments\"])) for lb in labels)\n len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))\n if len_segments and len_boxes != len_segments:\n LOGGER.warning(\n f\"Box and segment counts should be equal, but got len(segments) = {len_segments}, \"\n f\"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. 
\"\n \"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.\"\n )\n for lb in labels:\n lb[\"segments\"] = []\n if len_cls == 0:\n LOGGER.warning(f\"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}\")\n return labels\n\n def build_transforms(self, hyp: Optional[Dict] = None) -> Compose:\n \"\"\"\n Build and append transforms to the list.\n\n Args:\n hyp (dict, optional): Hyperparameters for transforms.\n\n Returns:\n (Compose): Composed transforms.\n \"\"\"\n if self.augment:\n hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0\n hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0\n hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0\n transforms = v8_transforms(self, self.imgsz, hyp)\n else:\n transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])\n transforms.append(\n Format(\n bbox_format=\"xywh\",\n normalize=True,\n return_mask=self.use_segments,\n return_keypoint=self.use_keypoints,\n return_obb=self.use_obb,\n batch_idx=True,\n mask_ratio=hyp.mask_ratio,\n mask_overlap=hyp.overlap_mask,\n bgr=hyp.bgr if self.augment else 0.0, # only affect training.\n )\n )\n return transforms\n\n def close_mosaic(self, hyp: Dict) -> None:\n \"\"\"\n Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.\n\n Args:\n hyp (dict): Hyperparameters for transforms.\n \"\"\"\n hyp.mosaic = 0.0\n hyp.copy_paste = 0.0\n hyp.mixup = 0.0\n hyp.cutmix = 0.0\n self.transforms = self.build_transforms(hyp)\n\n def update_labels_info(self, label: Dict) -> Dict:\n \"\"\"\n Update label format for different tasks.\n\n Args:\n label (dict): Label dictionary containing bboxes, segments, keypoints, etc.\n\n Returns:\n (dict): Updated label dictionary with instances.\n\n Note:\n cls is not with bboxes now, classification and semantic segmentation need an independent cls label\n Can also support classification and semantic segmentation by adding or removing dict keys there.\n \"\"\"\n bboxes = label.pop(\"bboxes\")\n segments = label.pop(\"segments\", [])\n keypoints = label.pop(\"keypoints\", None)\n bbox_format = label.pop(\"bbox_format\")\n normalized = label.pop(\"normalized\")\n\n # NOTE: do NOT resample oriented boxes\n segment_resamples = 100 if self.use_obb else 1000\n if len(segments) > 0:\n # make sure segments interpolate correctly if original length is greater than segment_resamples\n max_len = max(len(s) for s in segments)\n segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples\n # list[np.array(segment_resamples, 2)] * num_samples\n segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)\n else:\n segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)\n label[\"instances\"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)\n return label\n\n @staticmethod\n def collate_fn(batch: List[Dict]) -> Dict:\n \"\"\"\n Collate data samples into batches.\n\n Args:\n batch (List[dict]): List of dictionaries containing sample data.\n\n Returns:\n (dict): Collated batch with stacked tensors.\n \"\"\"\n new_batch = {}\n batch = [dict(sorted(b.items())) for b in batch] # make sure the keys are in the same order\n keys = batch[0].keys()\n values = list(zip(*[list(b.values()) for b in batch]))\n for i, k in enumerate(keys):\n value = values[i]\n if k in {\"img\", \"text_feats\"}:\n value = torch.stack(value, 0)\n elif k 
== \"visuals\":\n value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)\n if k in {\"masks\", \"keypoints\", \"bboxes\", \"cls\", \"segments\", \"obb\"}:\n value = torch.cat(value, 0)\n new_batch[k] = value\n new_batch[\"batch_idx\"] = list(new_batch[\"batch_idx\"])\n for i in range(len(new_batch[\"batch_idx\"])):\n new_batch[\"batch_idx\"][i] += i # add target image index for build_targets()\n new_batch[\"batch_idx\"] = torch.cat(new_batch[\"batch_idx\"], 0)\n return new_batch", "chunk_type": "class", "name": "YOLODataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 47, "end_line": 314, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Dataset class for loading object detection and/or segmentation labels in YOLO format.\n\nThis class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box\n(OBB) tasks using the YOLO format.\n\nAttributes:\n use_segments (bool): Indicates if segmentation masks should be used.\n use_keypoints (bool): Indicates if keypoints should be used for pose estimation.\n use_obb (bool): Indicates if oriented bounding boxes should be used.\n data (dict): Dataset configuration dictionary.\n\nMethods:\n cache_labels: Cache dataset labels, check images and read shapes.\n get_labels: Return dictionary of labels for YOLO training.\n build_transforms: Build and append transforms to the list.\n close_mosaic: Set mosaic, copy_paste and mixup options to 0.0 and build transformations.\n update_labels_info: Update label format for different tasks.\n collate_fn: Collate data samples into batches.\n\nExamples:\n >>> dataset = YOLODataset(img_path=\"path/to/images\", data={\"names\": {0: \"person\"}}, task=\"detect\")\n >>> dataset.get_labels()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "collections.defaultdict", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "torch", "PIL.Image", "torch.utils.data.ConcatDataset", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.colorstr", "ultralytics.utils.instance.Instances", "ultralytics.utils.ops.resample_segments", "ultralytics.utils.ops.segments2boxes", "ultralytics.utils.torch_utils.TORCHVISION_0_18", "augment.Compose", "augment.Format", "augment.LetterBox", "augment.RandomLoadText", "augment.classify_augmentations", "augment.classify_transforms", "augment.v8_transforms", "base.BaseDataset", "converter.merge_multi_segment", "utils.HELP_URL", "utils.check_file_speeds", "utils.get_hash", "utils.img2label_paths", "utils.load_dataset_cache_file", "utils.save_dataset_cache_file", "utils.verify_image", "utils.verify_image_label", "torchvision", "BaseDataset" ], "chunk_id": "class_YOLODataset_d3781939" }, { "content": "class YOLOMultiModalDataset(YOLODataset):\n \"\"\"\n Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.\n\n This class extends YOLODataset to add text information for multi-modal model training, enabling models to\n process both image and text data.\n\n Methods:\n update_labels_info: Add text information for multi-modal model training.\n build_transforms: Enhance data transformations with text augmentation.\n\n Examples:\n >>> dataset = YOLOMultiModalDataset(img_path=\"path/to/images\", 
data={\"names\": {0: \"person\"}}, task=\"detect\")\n >>> batch = next(iter(dataset))\n >>> print(batch.keys()) # Should include 'texts'\n \"\"\"\n\n def __init__(self, *args, data: Optional[Dict] = None, task: str = \"detect\", **kwargs):\n \"\"\"\n Initialize a YOLOMultiModalDataset.\n\n Args:\n data (dict, optional): Dataset configuration dictionary.\n task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.\n *args (Any): Additional positional arguments for the parent class.\n **kwargs (Any): Additional keyword arguments for the parent class.\n \"\"\"\n super().__init__(*args, data=data, task=task, **kwargs)\n\n def update_labels_info(self, label: Dict) -> Dict:\n \"\"\"\n Add text information for multi-modal model training.\n\n Args:\n label (dict): Label dictionary containing bboxes, segments, keypoints, etc.\n\n Returns:\n (dict): Updated label dictionary with instances and texts.\n \"\"\"\n labels = super().update_labels_info(label)\n # NOTE: some categories are concatenated with its synonyms by `/`.\n # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words.\n labels[\"texts\"] = [v.split(\"/\") for _, v in self.data[\"names\"].items()]\n\n return labels\n\n def build_transforms(self, hyp: Optional[Dict] = None) -> Compose:\n \"\"\"\n Enhance data transformations with optional text augmentation for multi-modal training.\n\n Args:\n hyp (dict, optional): Hyperparameters for transforms.\n\n Returns:\n (Compose): Composed transforms including text augmentation if applicable.\n \"\"\"\n transforms = super().build_transforms(hyp)\n if self.augment:\n # NOTE: hard-coded the args for now.\n # NOTE: this implementation is different from official yoloe,\n # the strategy of selecting negative is restricted in one dataset,\n # while official pre-saved neg embeddings from all datasets at once.\n transform = RandomLoadText(\n max_samples=min(self.data[\"nc\"], 80),\n padding=True,\n padding_value=self._get_neg_texts(self.category_freq),\n )\n transforms.insert(-1, transform)\n return transforms\n\n @property\n def category_names(self):\n \"\"\"\n Return category names for the dataset.\n\n Returns:\n (Set[str]): List of class names.\n \"\"\"\n names = self.data[\"names\"].values()\n return {n.strip() for name in names for n in name.split(\"/\")} # category names\n\n @property\n def category_freq(self):\n \"\"\"Return frequency of each category in the dataset.\"\"\"\n texts = [v.split(\"/\") for v in self.data[\"names\"].values()]\n category_freq = defaultdict(int)\n for label in self.labels:\n for c in label[\"cls\"].squeeze(-1): # to check\n text = texts[int(c)]\n for t in text:\n t = t.strip()\n category_freq[t] += 1\n return category_freq\n\n @staticmethod\n def _get_neg_texts(category_freq: Dict, threshold: int = 100) -> List[str]:\n \"\"\"Get negative text samples based on frequency threshold.\"\"\"\n threshold = min(max(category_freq.values()), 100)\n return [k for k, v in category_freq.items() if v >= threshold]", "chunk_type": "class", "name": "YOLOMultiModalDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 317, "end_line": 415, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": "Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.\n\nThis class extends YOLODataset to add text information for multi-modal model training, enabling models to\nprocess both image and text data.\n\nMethods:\n update_labels_info: Add text information for 
multi-modal model training.\n build_transforms: Enhance data transformations with text augmentation.\n\nExamples:\n >>> dataset = YOLOMultiModalDataset(img_path=\"path/to/images\", data={\"names\": {0: \"person\"}}, task=\"detect\")\n >>> batch = next(iter(dataset))\n >>> print(batch.keys()) # Should include 'texts'", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "collections.defaultdict", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "torch", "PIL.Image", "torch.utils.data.ConcatDataset", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.colorstr", "ultralytics.utils.instance.Instances", "ultralytics.utils.ops.resample_segments", "ultralytics.utils.ops.segments2boxes", "ultralytics.utils.torch_utils.TORCHVISION_0_18", "augment.Compose", "augment.Format", "augment.LetterBox", "augment.RandomLoadText", "augment.classify_augmentations", "augment.classify_transforms", "augment.v8_transforms", "base.BaseDataset", "converter.merge_multi_segment", "utils.HELP_URL", "utils.check_file_speeds", "utils.get_hash", "utils.img2label_paths", "utils.load_dataset_cache_file", "utils.save_dataset_cache_file", "utils.verify_image", "utils.verify_image_label", "torchvision", "YOLODataset" ], "chunk_id": "class_YOLOMultiModalDataset_a043cac4" }, { "content": "class GroundingDataset(YOLODataset):\n \"\"\"\n Dataset class for object detection tasks using annotations from a JSON file in grounding format.\n\n This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than\n the standard YOLO format text files.\n\n Attributes:\n json_file (str): Path to the JSON file containing annotations.\n\n Methods:\n get_img_files: Return empty list as image files are read in get_labels.\n get_labels: Load annotations from a JSON file and prepare them for training.\n build_transforms: Configure augmentations for training with optional text loading.\n\n Examples:\n >>> dataset = GroundingDataset(img_path=\"path/to/images\", json_file=\"annotations.json\", task=\"detect\")\n >>> len(dataset) # Number of valid images with annotations\n \"\"\"\n\n def __init__(self, *args, task: str = \"detect\", json_file: str = \"\", max_samples: int = 80, **kwargs):\n \"\"\"\n Initialize a GroundingDataset for object detection.\n\n Args:\n json_file (str): Path to the JSON file containing annotations.\n task (str): Must be 'detect' or 'segment' for GroundingDataset.\n max_samples (int): Maximum number of samples to load for text augmentation.\n *args (Any): Additional positional arguments for the parent class.\n **kwargs (Any): Additional keyword arguments for the parent class.\n \"\"\"\n assert task in {\"detect\", \"segment\"}, \"GroundingDataset currently only supports `detect` and `segment` tasks\"\n self.json_file = json_file\n self.max_samples = max_samples\n super().__init__(*args, task=task, data={\"channels\": 3}, **kwargs)\n\n def get_img_files(self, img_path: str) -> List:\n \"\"\"\n The image files would be read in `get_labels` function, return empty list here.\n\n Args:\n img_path (str): Path to the directory containing images.\n\n Returns:\n (list): Empty list as image files are read in get_labels.\n \"\"\"\n return []\n\n def verify_labels(self, labels: List[Dict[str, Any]]) -> None:\n \"\"\"\n Verify the number of 
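The multi-modal variant differs mainly in the text prompts it attaches; a small sketch, with placeholder paths and a `names` mapping that uses the documented "/"-separated synonym convention:

```python
from ultralytics.data.dataset import YOLOMultiModalDataset

# Placeholder config: synonyms joined by "/" are split into per-class text
# prompts by update_labels_info(), and RandomLoadText samples among them.
data = {"channels": 3, "nc": 2, "names": {0: "person/pedestrian", 1: "car/automobile"}}
dataset = YOLOMultiModalDataset(img_path="path/to/images", data=data, task="detect")

sample = next(iter(dataset))
print(sample.keys())  # should include "texts" alongside the usual detection fields
```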
instances in the dataset matches expected counts.\n\n This method checks if the total number of bounding box instances in the provided\n labels matches the expected count for known datasets. It performs validation\n against a predefined set of datasets with known instance counts.\n\n Args:\n labels (List[Dict[str, Any]]): List of label dictionaries, where each dictionary\n contains dataset annotations. Each label dict must have a 'bboxes' key with\n a numpy array or tensor containing bounding box coordinates.\n\n Raises:\n AssertionError: If the actual instance count doesn't match the expected count\n for a recognized dataset.\n\n Note:\n For unrecognized datasets (those not in the predefined expected_counts),\n a warning is logged and verification is skipped.\n \"\"\"\n expected_counts = {\n \"final_mixed_train_no_coco_segm\": 3662412,\n \"final_mixed_train_no_coco\": 3681235,\n \"final_flickr_separateGT_train_segm\": 638214,\n \"final_flickr_separateGT_train\": 640704,\n }\n\n instance_count = sum(label[\"bboxes\"].shape[0] for label in labels)\n for data_name, count in expected_counts.items():\n if data_name in self.json_file:\n assert instance_count == count, f\"'{self.json_file}' has {instance_count} instances, expected {count}.\"\n return\n LOGGER.warning(f\"Skipping instance count verification for unrecognized dataset '{self.json_file}'\")\n\n def cache_labels(self, path: Path = Path(\"./labels.cache\")) -> Dict[str, Any]:\n \"\"\"\n Load annotations from a JSON file, filter, and normalize bounding boxes for each image.\n\n Args:\n path (Path): Path where to save the cache file.\n\n Returns:\n (Dict[str, Any]): Dictionary containing cached labels and related information.\n \"\"\"\n x = {\"labels\": []}\n LOGGER.info(\"Loading annotation file...\")\n with open(self.json_file) as f:\n annotations = json.load(f)\n images = {f\"{x['id']:d}\": x for x in annotations[\"images\"]}\n img_to_anns = defaultdict(list)\n for ann in annotations[\"annotations\"]:\n img_to_anns[ann[\"image_id\"]].append(ann)\n for img_id, anns in TQDM(img_to_anns.items(), desc=f\"Reading annotations {self.json_file}\"):\n img = images[f\"{img_id:d}\"]\n h, w, f = img[\"height\"], img[\"width\"], img[\"file_name\"]\n im_file = Path(self.img_path) / f\n if not im_file.exists():\n continue\n self.im_files.append(str(im_file))\n bboxes = []\n segments = []\n cat2id = {}\n texts = []\n for ann in anns:\n if ann[\"iscrowd\"]:\n continue\n box = np.array(ann[\"bbox\"], dtype=np.float32)\n box[:2] += box[2:] / 2\n box[[0, 2]] /= float(w)\n box[[1, 3]] /= float(h)\n if box[2] <= 0 or box[3] <= 0:\n continue\n\n caption = img[\"caption\"]\n cat_name = \" \".join([caption[t[0] : t[1]] for t in ann[\"tokens_positive\"]]).lower().strip()\n if not cat_name:\n continue\n\n if cat_name not in cat2id:\n cat2id[cat_name] = len(cat2id)\n texts.append([cat_name])\n cls = cat2id[cat_name] # class\n box = [cls] + box.tolist()\n if box not in bboxes:\n bboxes.append(box)\n if ann.get(\"segmentation\") is not None:\n if len(ann[\"segmentation\"]) == 0:\n segments.append(box)\n continue\n elif len(ann[\"segmentation\"]) > 1:\n s = merge_multi_segment(ann[\"segmentation\"])\n s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()\n else:\n s = [j for i in ann[\"segmentation\"] for j in i] # all segments concatenated\n s = (\n (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))\n .reshape(-1)\n .tolist()\n )\n s = [cls] + s\n segments.append(s)\n lb = np.array(bboxes, 
dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)\n\n if segments:\n classes = np.array([x[0] for x in segments], dtype=np.float32)\n segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments] # (cls, xy1...)\n lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)\n lb = np.array(lb, dtype=np.float32)\n\n x[\"labels\"].append(\n {\n \"im_file\": im_file,\n \"shape\": (h, w),\n \"cls\": lb[:, 0:1], # n, 1\n \"bboxes\": lb[:, 1:], # n, 4\n \"segments\": segments,\n \"normalized\": True,\n \"bbox_format\": \"xywh\",\n \"texts\": texts,\n }\n )\n x[\"hash\"] = get_hash(self.json_file)\n save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)\n return x\n\n def get_labels(self) -> List[Dict]:\n \"\"\"\n Load labels from cache or generate them from JSON file.\n\n Returns:\n (List[dict]): List of label dictionaries, each containing information about an image and its annotations.\n \"\"\"\n cache_path = Path(self.json_file).with_suffix(\".cache\")\n try:\n cache, _ = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file\n assert cache[\"version\"] == DATASET_CACHE_VERSION # matches current version\n assert cache[\"hash\"] == get_hash(self.json_file) # identical hash\n except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):\n cache, _ = self.cache_labels(cache_path), False # run cache ops\n [cache.pop(k) for k in (\"hash\", \"version\")] # remove items\n labels = cache[\"labels\"]\n self.verify_labels(labels)\n self.im_files = [str(label[\"im_file\"]) for label in labels]\n if LOCAL_RANK in {-1, 0}:\n LOGGER.info(f\"Load {self.json_file} from cache file {cache_path}\")\n return labels\n\n def build_transforms(self, hyp: Optional[Dict] = None) -> Compose:\n \"\"\"\n Configure augmentations for training with optional text loading.\n\n Args:\n hyp (dict, optional): Hyperparameters for transforms.\n\n Returns:\n (Compose): Composed transforms including text augmentation if applicable.\n \"\"\"\n transforms = super().build_transforms(hyp)\n if self.augment:\n # NOTE: hard-coded the args for now.\n # NOTE: this implementation is different from official yoloe,\n # the strategy of selecting negative is restricted in one dataset,\n # while official pre-saved neg embeddings from all datasets at once.\n transform = RandomLoadText(\n max_samples=min(self.max_samples, 80),\n padding=True,\n padding_value=self._get_neg_texts(self.category_freq),\n )\n transforms.insert(-1, transform)\n return transforms\n\n @property\n def category_names(self):\n \"\"\"Return unique category names from the dataset.\"\"\"\n return {t.strip() for label in self.labels for text in label[\"texts\"] for t in text}\n\n @property\n def category_freq(self):\n \"\"\"Return frequency of each category in the dataset.\"\"\"\n category_freq = defaultdict(int)\n for label in self.labels:\n for text in label[\"texts\"]:\n for t in text:\n t = t.strip()\n category_freq[t] += 1\n return category_freq\n\n @staticmethod\n def _get_neg_texts(category_freq: Dict, threshold: int = 100) -> List[str]:\n \"\"\"Get negative text samples based on frequency threshold.\"\"\"\n threshold = min(max(category_freq.values()), 100)\n return [k for k, v in category_freq.items() if v >= threshold]", "chunk_type": "class", "name": "GroundingDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 418, "end_line": 658, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": "Dataset class for 
object detection tasks using annotations from a JSON file in grounding format.\n\nThis dataset is designed for grounding tasks where annotations are provided in a JSON file rather than\nthe standard YOLO format text files.\n\nAttributes:\n json_file (str): Path to the JSON file containing annotations.\n\nMethods:\n get_img_files: Return empty list as image files are read in get_labels.\n get_labels: Load annotations from a JSON file and prepare them for training.\n build_transforms: Configure augmentations for training with optional text loading.\n\nExamples:\n >>> dataset = GroundingDataset(img_path=\"path/to/images\", json_file=\"annotations.json\", task=\"detect\")\n >>> len(dataset) # Number of valid images with annotations", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "collections.defaultdict", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "torch", "PIL.Image", "torch.utils.data.ConcatDataset", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.colorstr", "ultralytics.utils.instance.Instances", "ultralytics.utils.ops.resample_segments", "ultralytics.utils.ops.segments2boxes", "ultralytics.utils.torch_utils.TORCHVISION_0_18", "augment.Compose", "augment.Format", "augment.LetterBox", "augment.RandomLoadText", "augment.classify_augmentations", "augment.classify_transforms", "augment.v8_transforms", "base.BaseDataset", "converter.merge_multi_segment", "utils.HELP_URL", "utils.check_file_speeds", "utils.get_hash", "utils.img2label_paths", "utils.load_dataset_cache_file", "utils.save_dataset_cache_file", "utils.verify_image", "utils.verify_image_label", "torchvision", "YOLODataset" ], "chunk_id": "class_GroundingDataset_b46f1269" }, { "content": "class YOLOConcatDataset(ConcatDataset):\n \"\"\"\n Dataset as a concatenation of multiple datasets.\n\n This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same\n collation function.\n\n Methods:\n collate_fn: Static method that collates data samples into batches using YOLODataset's collation function.\n\n Examples:\n >>> dataset1 = YOLODataset(...)\n >>> dataset2 = YOLODataset(...)\n >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2])\n \"\"\"\n\n @staticmethod\n def collate_fn(batch: List[Dict]) -> Dict:\n \"\"\"\n Collate data samples into batches.\n\n Args:\n batch (List[dict]): List of dictionaries containing sample data.\n\n Returns:\n (dict): Collated batch with stacked tensors.\n \"\"\"\n return YOLODataset.collate_fn(batch)\n\n def close_mosaic(self, hyp: Dict) -> None:\n \"\"\"\n Set mosaic, copy_paste and mixup options to 0.0 and build transformations.\n\n Args:\n hyp (dict): Hyperparameters for transforms.\n \"\"\"\n for dataset in self.datasets:\n if not hasattr(dataset, \"close_mosaic\"):\n continue\n dataset.close_mosaic(hyp)", "chunk_type": "class", "name": "YOLOConcatDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 661, "end_line": 700, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Dataset as a concatenation of multiple datasets.\n\nThis class is useful to assemble different existing datasets for YOLO training, ensuring they use the same\ncollation function.\n\nMethods:\n collate_fn: Static method that collates data samples into batches using 
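`GroundingDataset` reads everything from its JSON file rather than per-image label files; a hedged usage sketch with placeholder paths:

```python
from ultralytics.data.dataset import GroundingDataset

# Placeholder paths: the JSON file must hold COCO-style images/annotations plus
# caption token spans ("tokens_positive") as parsed in cache_labels() above.
dataset = GroundingDataset(
    img_path="path/to/images",
    json_file="path/to/annotations.json",
    task="detect",   # "segment" is the only other supported task
    max_samples=80,  # cap on text prompts fed to RandomLoadText during augmentation
)
print(len(dataset))  # number of images with at least one usable annotation
```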
YOLODataset's collation function.\n\nExamples:\n >>> dataset1 = YOLODataset(...)\n >>> dataset2 = YOLODataset(...)\n >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "collections.defaultdict", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "torch", "PIL.Image", "torch.utils.data.ConcatDataset", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.colorstr", "ultralytics.utils.instance.Instances", "ultralytics.utils.ops.resample_segments", "ultralytics.utils.ops.segments2boxes", "ultralytics.utils.torch_utils.TORCHVISION_0_18", "augment.Compose", "augment.Format", "augment.LetterBox", "augment.RandomLoadText", "augment.classify_augmentations", "augment.classify_transforms", "augment.v8_transforms", "base.BaseDataset", "converter.merge_multi_segment", "utils.HELP_URL", "utils.check_file_speeds", "utils.get_hash", "utils.img2label_paths", "utils.load_dataset_cache_file", "utils.save_dataset_cache_file", "utils.verify_image", "utils.verify_image_label", "torchvision", "ConcatDataset" ], "chunk_id": "class_YOLOConcatDataset_e00e68fb" }, { "content": "class SemanticDataset(BaseDataset):\n \"\"\"Semantic Segmentation Dataset.\"\"\"\n\n def __init__(self):\n \"\"\"Initialize a SemanticDataset object.\"\"\"\n super().__init__()", "chunk_type": "class", "name": "SemanticDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 704, "end_line": 709, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "Semantic Segmentation Dataset.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "collections.defaultdict", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "torch", "PIL.Image", "torch.utils.data.ConcatDataset", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.colorstr", "ultralytics.utils.instance.Instances", "ultralytics.utils.ops.resample_segments", "ultralytics.utils.ops.segments2boxes", "ultralytics.utils.torch_utils.TORCHVISION_0_18", "augment.Compose", "augment.Format", "augment.LetterBox", "augment.RandomLoadText", "augment.classify_augmentations", "augment.classify_transforms", "augment.v8_transforms", "base.BaseDataset", "converter.merge_multi_segment", "utils.HELP_URL", "utils.check_file_speeds", "utils.get_hash", "utils.img2label_paths", "utils.load_dataset_cache_file", "utils.save_dataset_cache_file", "utils.verify_image", "utils.verify_image_label", "torchvision", "BaseDataset" ], "chunk_id": "class_SemanticDataset_f05eb6af" }, { "content": "class ClassificationDataset:\n \"\"\"\n Dataset class for image classification tasks extending torchvision ImageFolder functionality.\n\n This class offers functionalities like image augmentation, caching, and verification. 
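Concatenating datasets for training is a one-liner once the member datasets exist; a sketch, assuming both use the same placeholder data config:

```python
from torch.utils.data import DataLoader

from ultralytics.data.dataset import YOLOConcatDataset, YOLODataset

data = {"channels": 3, "nc": 1, "names": {0: "person"}}  # placeholder config
train_a = YOLODataset(img_path="path/to/images_a", data=data, task="detect")
train_b = YOLODataset(img_path="path/to/images_b", data=data, task="detect")

# Both members must yield YOLO-style sample dicts so the shared collate_fn
# (delegating to YOLODataset.collate_fn) can batch them together.
combined = YOLOConcatDataset([train_a, train_b])
loader = DataLoader(combined, batch_size=16, collate_fn=YOLOConcatDataset.collate_fn)
```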
It's designed to efficiently\n handle large datasets for training deep learning models, with optional image transformations and caching mechanisms\n to speed up training.\n\n Attributes:\n cache_ram (bool): Indicates if caching in RAM is enabled.\n cache_disk (bool): Indicates if caching on disk is enabled.\n samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache\n file (if caching on disk), and optionally the loaded image array (if caching in RAM).\n torch_transforms (callable): PyTorch transforms to be applied to the images.\n root (str): Root directory of the dataset.\n prefix (str): Prefix for logging and cache filenames.\n\n Methods:\n __getitem__: Return subset of data and targets corresponding to given indices.\n __len__: Return the total number of samples in the dataset.\n verify_images: Verify all images in dataset.\n \"\"\"\n\n def __init__(self, root: str, args, augment: bool = False, prefix: str = \"\"):\n \"\"\"\n Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.\n\n Args:\n root (str): Path to the dataset directory where images are stored in a class-specific folder structure.\n args (Namespace): Configuration containing dataset-related settings such as image size, augmentation\n parameters, and cache settings.\n augment (bool, optional): Whether to apply augmentations to the dataset.\n prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import\n if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18\n self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)\n else:\n self.base = torchvision.datasets.ImageFolder(root=root)\n self.samples = self.base.samples\n self.root = self.base.root\n\n # Initialize attributes\n if augment and args.fraction < 1.0: # reduce training fraction\n self.samples = self.samples[: round(len(self.samples) * args.fraction)]\n self.prefix = colorstr(f\"{prefix}: \") if prefix else \"\"\n self.cache_ram = args.cache is True or str(args.cache).lower() == \"ram\" # cache images into RAM\n if self.cache_ram:\n LOGGER.warning(\n \"Classification `cache_ram` training has known memory leak in \"\n \"https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`.\"\n )\n self.cache_ram = False\n self.cache_disk = str(args.cache).lower() == \"disk\" # cache images on hard drive as uncompressed *.npy files\n self.samples = self.verify_images() # filter out bad images\n self.samples = [list(x) + [Path(x[0]).with_suffix(\".npy\"), None] for x in self.samples] # file, index, npy, im\n scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)\n self.torch_transforms = (\n classify_augmentations(\n size=args.imgsz,\n scale=scale,\n hflip=args.fliplr,\n vflip=args.flipud,\n erasing=args.erasing,\n auto_augment=args.auto_augment,\n hsv_h=args.hsv_h,\n hsv_s=args.hsv_s,\n hsv_v=args.hsv_v,\n )\n if augment\n else classify_transforms(size=args.imgsz)\n )\n\n def __getitem__(self, i: int) -> Dict:\n \"\"\"\n Return subset of data and targets corresponding to given indices.\n\n Args:\n i (int): Index of the sample to retrieve.\n\n Returns:\n (dict): Dictionary containing the image and its class index.\n \"\"\"\n f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), 
image\n if self.cache_ram:\n if im is None: # Warning: two separate if statements required here, do not combine this with previous line\n im = self.samples[i][3] = cv2.imread(f)\n elif self.cache_disk:\n if not fn.exists(): # load npy\n np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)\n im = np.load(fn)\n else: # read image\n im = cv2.imread(f) # BGR\n # Convert NumPy array to PIL image\n im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))\n sample = self.torch_transforms(im)\n return {\"img\": sample, \"cls\": j}\n\n def __len__(self) -> int:\n \"\"\"Return the total number of samples in the dataset.\"\"\"\n return len(self.samples)\n\n def verify_images(self) -> List[Tuple]:\n \"\"\"\n Verify all images in dataset.\n\n Returns:\n (list): List of valid samples after verification.\n \"\"\"\n desc = f\"{self.prefix}Scanning {self.root}...\"\n path = Path(self.root).with_suffix(\".cache\") # *.cache file path\n\n try:\n check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix) # check image read speeds\n cache = load_dataset_cache_file(path) # attempt to load a *.cache file\n assert cache[\"version\"] == DATASET_CACHE_VERSION # matches current version\n assert cache[\"hash\"] == get_hash([x[0] for x in self.samples]) # identical hash\n nf, nc, n, samples = cache.pop(\"results\") # found, missing, empty, corrupt, total\n if LOCAL_RANK in {-1, 0}:\n d = f\"{desc} {nf} images, {nc} corrupt\"\n TQDM(None, desc=d, total=n, initial=n)\n if cache[\"msgs\"]:\n LOGGER.info(\"\\n\".join(cache[\"msgs\"])) # display warnings\n return samples\n\n except (FileNotFoundError, AssertionError, AttributeError):\n # Run scan if *.cache retrieval failed\n nf, nc, msgs, samples, x = 0, 0, [], [], {}\n with ThreadPool(NUM_THREADS) as pool:\n results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))\n pbar = TQDM(results, desc=desc, total=len(self.samples))\n for sample, nf_f, nc_f, msg in pbar:\n if nf_f:\n samples.append(sample)\n if msg:\n msgs.append(msg)\n nf += nf_f\n nc += nc_f\n pbar.desc = f\"{desc} {nf} images, {nc} corrupt\"\n pbar.close()\n if msgs:\n LOGGER.info(\"\\n\".join(msgs))\n x[\"hash\"] = get_hash([x[0] for x in self.samples])\n x[\"results\"] = nf, nc, len(samples), samples\n x[\"msgs\"] = msgs # warnings\n save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)\n return samples", "chunk_type": "class", "name": "ClassificationDataset", "file_path": "ultralytics\\ultralytics\\data\\dataset.py", "start_line": 712, "end_line": 860, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "Dataset class for image classification tasks extending torchvision ImageFolder functionality.\n\nThis class offers functionalities like image augmentation, caching, and verification. 
It's designed to efficiently\nhandle large datasets for training deep learning models, with optional image transformations and caching mechanisms\nto speed up training.\n\nAttributes:\n cache_ram (bool): Indicates if caching in RAM is enabled.\n cache_disk (bool): Indicates if caching on disk is enabled.\n samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache\n file (if caching on disk), and optionally the loaded image array (if caching in RAM).\n torch_transforms (callable): PyTorch transforms to be applied to the images.\n root (str): Root directory of the dataset.\n prefix (str): Prefix for logging and cache filenames.\n\nMethods:\n __getitem__: Return subset of data and targets corresponding to given indices.\n __len__: Return the total number of samples in the dataset.\n verify_images: Verify all images in dataset.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "collections.defaultdict", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "torch", "PIL.Image", "torch.utils.data.ConcatDataset", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.TQDM", "ultralytics.utils.colorstr", "ultralytics.utils.instance.Instances", "ultralytics.utils.ops.resample_segments", "ultralytics.utils.ops.segments2boxes", "ultralytics.utils.torch_utils.TORCHVISION_0_18", "augment.Compose", "augment.Format", "augment.LetterBox", "augment.RandomLoadText", "augment.classify_augmentations", "augment.classify_transforms", "augment.v8_transforms", "base.BaseDataset", "converter.merge_multi_segment", "utils.HELP_URL", "utils.check_file_speeds", "utils.get_hash", "utils.img2label_paths", "utils.load_dataset_cache_file", "utils.save_dataset_cache_file", "utils.verify_image", "utils.verify_image_label", "torchvision" ], "chunk_id": "class_ClassificationDataset_b31e4881" }, { "content": "import glob", "chunk_type": "import", "name": "glob", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_glob_f0db6417" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_dfc0a24b" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_6aed6483" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
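`ClassificationDataset` expects an ImageFolder-style directory and an `args` namespace carrying image-size, augmentation and cache settings; the sketch below assumes ultralytics' `DEFAULT_CFG` namespace can stand in for that config, which is an assumption rather than a documented requirement:

```python
from ultralytics.data.dataset import ClassificationDataset
from ultralytics.utils import DEFAULT_CFG  # assumed stand-in for a full training config namespace

# Placeholder root: one sub-directory per class, as with torchvision ImageFolder.
dataset = ClassificationDataset(root="path/to/train", args=DEFAULT_CFG, augment=True, prefix="train")

sample = dataset[0]
print(sample["img"].shape, sample["cls"])  # transformed image tensor and class index
```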
"import_time_1b546745" }, { "content": "import urllib", "chunk_type": "import", "name": "urllib", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_urllib_690426d3" }, { "content": "from dataclasses import dataclass", "chunk_type": "import", "name": "dataclass", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_dataclass_c6078cfd" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_fe6ef926" }, { "content": "from threading import Thread", "chunk_type": "import", "name": "Thread", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Thread_aaf6f1ab" }, { "content": "from typing import Any, List, Optional, Tuple, Union", "chunk_type": "import", "name": "Any, List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Optional, Tuple, Union_eadcbf08" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_dbd368f3" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_e811657e" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_abfb6a52" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
"complexity_score": null, "dependencies": null, "chunk_id": "import_Image_caef6ca8" }, { "content": "from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS", "chunk_type": "import", "name": "FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS_a4ad2f01" }, { "content": "from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops", "chunk_type": "import", "name": "IS_COLAB, IS_KAGGLE, LOGGER, ops", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IS_COLAB, IS_KAGGLE, LOGGER, ops_1ef4d325" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_bc04ee8e" }, { "content": "from ultralytics.utils.patches import imread", "chunk_type": "import", "name": "imread", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_imread_2accd0a8" }, { "content": "class SourceTypes:\n \"\"\"\n Class to represent various types of input sources for predictions.\n\n This class uses dataclass to define boolean flags for different types of input sources that can be used for\n making predictions with YOLO models.\n\n Attributes:\n stream (bool): Flag indicating if the input source is a video stream.\n screenshot (bool): Flag indicating if the input source is a screenshot.\n from_img (bool): Flag indicating if the input source is an image file.\n tensor (bool): Flag indicating if the input source is a tensor.\n\n Examples:\n >>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False)\n >>> print(source_types.stream)\n True\n >>> print(source_types.from_img)\n False\n \"\"\"\n\n stream: bool = False\n screenshot: bool = False\n from_img: bool = False\n tensor: bool = False", "chunk_type": "class", "name": "SourceTypes", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 25, "end_line": 49, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Class to represent various types of input sources for predictions.\n\nThis class uses dataclass to define boolean flags for different types of input sources that can be used for\nmaking predictions with YOLO models.\n\nAttributes:\n stream (bool): Flag indicating if the input source is a video stream.\n screenshot (bool): Flag indicating if the input source is a screenshot.\n from_img (bool): Flag indicating if the input source is an image file.\n tensor (bool): Flag indicating if the input source is a tensor.\n\nExamples:\n >>> 
source_types = SourceTypes(stream=True, screenshot=False, from_img=False)\n >>> print(source_types.stream)\n True\n >>> print(source_types.from_img)\n False", "parameters": null, "return_type": null, "decorators": [ "dataclass" ], "complexity_score": null, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "class_SourceTypes_0edce07d" }, { "content": "class LoadStreams:\n \"\"\"\n Stream Loader for various types of video streams.\n\n Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video\n streams simultaneously, making it suitable for real-time video analysis tasks.\n\n Attributes:\n sources (List[str]): The source input paths or URLs for the video streams.\n vid_stride (int): Video frame-rate stride.\n buffer (bool): Whether to buffer input streams.\n running (bool): Flag to indicate if the streaming thread is running.\n mode (str): Set to 'stream' indicating real-time capture.\n imgs (List[List[np.ndarray]]): List of image frames for each stream.\n fps (List[float]): List of FPS for each stream.\n frames (List[int]): List of total frames for each stream.\n threads (List[Thread]): List of threads for each stream.\n shape (List[Tuple[int, int, int]]): List of shapes for each stream.\n caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream.\n bs (int): Batch size for processing.\n cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).\n\n Methods:\n update: Read stream frames in daemon thread.\n close: Close stream loader and release resources.\n __iter__: Returns an iterator object for the class.\n __next__: Returns source paths, transformed, and original images for processing.\n __len__: Return the length of the sources object.\n\n Examples:\n >>> stream_loader = LoadStreams(\"rtsp://example.com/stream1.mp4\")\n >>> for sources, imgs, _ in stream_loader:\n ... # Process the images\n ... 
pass\n >>> stream_loader.close()\n\n Notes:\n - The class uses threading to efficiently load frames from multiple streams simultaneously.\n - It automatically handles YouTube links, converting them to the best available stream URL.\n - The class implements a buffer system to manage frame storage and retrieval.\n \"\"\"\n\n def __init__(self, sources: str = \"file.streams\", vid_stride: int = 1, buffer: bool = False, channels: int = 3):\n \"\"\"\n Initialize stream loader for multiple video sources, supporting various stream types.\n\n Args:\n sources (str): Path to streams file or single stream URL.\n vid_stride (int): Video frame-rate stride.\n buffer (bool): Whether to buffer input streams.\n channels (int): Number of image channels (1 for grayscale, 3 for RGB).\n \"\"\"\n torch.backends.cudnn.benchmark = True # faster for fixed-size inference\n self.buffer = buffer # buffer input streams\n self.running = True # running flag for Thread\n self.mode = \"stream\"\n self.vid_stride = vid_stride # video frame-rate stride\n self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR # grayscale or RGB\n\n sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]\n n = len(sources)\n self.bs = n\n self.fps = [0] * n # frames per second\n self.frames = [0] * n\n self.threads = [None] * n\n self.caps = [None] * n # video capture objects\n self.imgs = [[] for _ in range(n)] # images\n self.shape = [[] for _ in range(n)] # image shapes\n self.sources = [ops.clean_str(x).replace(os.sep, \"_\") for x in sources] # clean source names for later\n for i, s in enumerate(sources): # index, source\n # Start thread to read frames from video stream\n st = f\"{i + 1}/{n}: {s}... \"\n if urllib.parse.urlparse(s).hostname in {\"www.youtube.com\", \"youtube.com\", \"youtu.be\"}: # YouTube video\n # YouTube format i.e. 'https://www.youtube.com/watch?v=Jsn8D3aC840' or 'https://youtu.be/Jsn8D3aC840'\n s = get_best_youtube_url(s)\n s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam\n if s == 0 and (IS_COLAB or IS_KAGGLE):\n raise NotImplementedError(\n \"'source=0' webcam not supported in Colab and Kaggle notebooks. 
\"\n \"Try running 'source=0' in a local environment.\"\n )\n self.caps[i] = cv2.VideoCapture(s) # store video capture object\n if not self.caps[i].isOpened():\n raise ConnectionError(f\"{st}Failed to open {s}\")\n w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))\n h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))\n fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan\n self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(\n \"inf\"\n ) # infinite stream fallback\n self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback\n\n success, im = self.caps[i].read() # guarantee first frame\n im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im\n if not success or im is None:\n raise ConnectionError(f\"{st}Failed to read images from {s}\")\n self.imgs[i].append(im)\n self.shape[i] = im.shape\n self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)\n LOGGER.info(f\"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)\")\n self.threads[i].start()\n LOGGER.info(\"\") # newline\n\n def update(self, i: int, cap: cv2.VideoCapture, stream: str):\n \"\"\"Read stream frames in daemon thread and update image buffer.\"\"\"\n n, f = 0, self.frames[i] # frame number, frame array\n while self.running and cap.isOpened() and n < (f - 1):\n if len(self.imgs[i]) < 30: # keep a <=30-image buffer\n n += 1\n cap.grab() # .read() = .grab() followed by .retrieve()\n if n % self.vid_stride == 0:\n success, im = cap.retrieve()\n im = (\n cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im\n )\n if not success:\n im = np.zeros(self.shape[i], dtype=np.uint8)\n LOGGER.warning(\"Video stream unresponsive, please check your IP camera connection.\")\n cap.open(stream) # re-open stream if signal was lost\n if self.buffer:\n self.imgs[i].append(im)\n else:\n self.imgs[i] = [im]\n else:\n time.sleep(0.01) # wait until the buffer is empty\n\n def close(self):\n \"\"\"Terminate stream loader, stop threads, and release video capture resources.\"\"\"\n self.running = False # stop flag for Thread\n for thread in self.threads:\n if thread.is_alive():\n thread.join(timeout=5) # Add timeout\n for cap in self.caps: # Iterate through the stored VideoCapture objects\n try:\n cap.release() # release video capture\n except Exception as e:\n LOGGER.warning(f\"Could not release VideoCapture object: {e}\")\n\n def __iter__(self):\n \"\"\"Iterate through YOLO image feed and re-open unresponsive streams.\"\"\"\n self.count = -1\n return self\n\n def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:\n \"\"\"Return the next batch of frames from multiple video streams for processing.\"\"\"\n self.count += 1\n\n images = []\n for i, x in enumerate(self.imgs):\n # Wait until a frame is available in each buffer\n while not x:\n if not self.threads[i].is_alive():\n self.close()\n raise StopIteration\n time.sleep(1 / min(self.fps))\n x = self.imgs[i]\n if not x:\n LOGGER.warning(f\"Waiting for stream {i}\")\n\n # Get and remove the first frame from imgs buffer\n if self.buffer:\n images.append(x.pop(0))\n\n # Get the last frame, and clear the rest from the imgs buffer\n else:\n images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))\n x.clear()\n\n return self.sources, images, [\"\"] * self.bs\n\n def __len__(self) -> int:\n \"\"\"Return the number of video streams in the LoadStreams 
object.\"\"\"\n return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years", "chunk_type": "class", "name": "LoadStreams", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 52, "end_line": 224, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Stream Loader for various types of video streams.\n\nSupports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video\nstreams simultaneously, making it suitable for real-time video analysis tasks.\n\nAttributes:\n sources (List[str]): The source input paths or URLs for the video streams.\n vid_stride (int): Video frame-rate stride.\n buffer (bool): Whether to buffer input streams.\n running (bool): Flag to indicate if the streaming thread is running.\n mode (str): Set to 'stream' indicating real-time capture.\n imgs (List[List[np.ndarray]]): List of image frames for each stream.\n fps (List[float]): List of FPS for each stream.\n frames (List[int]): List of total frames for each stream.\n threads (List[Thread]): List of threads for each stream.\n shape (List[Tuple[int, int, int]]): List of shapes for each stream.\n caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream.\n bs (int): Batch size for processing.\n cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).\n\nMethods:\n update: Read stream frames in daemon thread.\n close: Close stream loader and release resources.\n __iter__: Returns an iterator object for the class.\n __next__: Returns source paths, transformed, and original images for processing.\n __len__: Return the length of the sources object.\n\nExamples:\n >>> stream_loader = LoadStreams(\"rtsp://example.com/stream1.mp4\")\n >>> for sources, imgs, _ in stream_loader:\n ... # Process the images\n ... pass\n >>> stream_loader.close()\n\nNotes:\n - The class uses threading to efficiently load frames from multiple streams simultaneously.\n - It automatically handles YouTube links, converting them to the best available stream URL.\n - The class implements a buffer system to manage frame storage and retrieval.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "class_LoadStreams_c9b13d4e" }, { "content": "class LoadScreenshots:\n \"\"\"\n Ultralytics screenshot dataloader for capturing and processing screen images.\n\n This class manages the loading of screenshot images for processing with YOLO. 
It is suitable for use with\n `yolo predict source=screen`.\n\n Attributes:\n source (str): The source input indicating which screen to capture.\n screen (int): The screen number to capture.\n left (int): The left coordinate for screen capture area.\n top (int): The top coordinate for screen capture area.\n width (int): The width of the screen capture area.\n height (int): The height of the screen capture area.\n mode (str): Set to 'stream' indicating real-time capture.\n frame (int): Counter for captured frames.\n sct (mss.mss): Screen capture object from `mss` library.\n bs (int): Batch size, set to 1.\n fps (int): Frames per second, set to 30.\n monitor (Dict[str, int]): Monitor configuration details.\n cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).\n\n Methods:\n __iter__: Returns an iterator object.\n __next__: Captures the next screenshot and returns it.\n\n Examples:\n >>> loader = LoadScreenshots(\"0 100 100 640 480\") # screen 0, top-left (100,100), 640x480\n >>> for source, im, im0s, vid_cap, s in loader:\n ... print(f\"Captured frame: {im.shape}\")\n \"\"\"\n\n def __init__(self, source: str, channels: int = 3):\n \"\"\"\n Initialize screenshot capture with specified screen and region parameters.\n\n Args:\n source (str): Screen capture source string in format \"screen_num left top width height\".\n channels (int): Number of image channels (1 for grayscale, 3 for RGB).\n \"\"\"\n check_requirements(\"mss\")\n import mss # noqa\n\n source, *params = source.split()\n self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0\n if len(params) == 1:\n self.screen = int(params[0])\n elif len(params) == 4:\n left, top, width, height = (int(x) for x in params)\n elif len(params) == 5:\n self.screen, left, top, width, height = (int(x) for x in params)\n self.mode = \"stream\"\n self.frame = 0\n self.sct = mss.mss()\n self.bs = 1\n self.fps = 30\n self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR # grayscale or RGB\n\n # Parse monitor shape\n monitor = self.sct.monitors[self.screen]\n self.top = monitor[\"top\"] if top is None else (monitor[\"top\"] + top)\n self.left = monitor[\"left\"] if left is None else (monitor[\"left\"] + left)\n self.width = width or monitor[\"width\"]\n self.height = height or monitor[\"height\"]\n self.monitor = {\"left\": self.left, \"top\": self.top, \"width\": self.width, \"height\": self.height}\n\n def __iter__(self):\n \"\"\"Yield the next screenshot image from the specified screen or region for processing.\"\"\"\n return self\n\n def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:\n \"\"\"Capture and return the next screenshot as a numpy array using the mss library.\"\"\"\n im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR\n im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im0\n s = f\"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: \"\n\n self.frame += 1\n return [str(self.screen)], [im0], [s] # screen, img, string", "chunk_type": "class", "name": "LoadScreenshots", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 227, "end_line": 304, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "Ultralytics screenshot dataloader for capturing and processing screen images.\n\nThis class manages the loading of screenshot images for processing with YOLO. 
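Note that `__next__` above returns three lists (screen ids, images, info strings), whereas the docstring example unpacks five values; a usage sketch consistent with the return signature, reusing the capture region from that example:

from ultralytics.data.loaders import LoadScreenshots

loader = LoadScreenshots("0 100 100 640 480")   # screen 0, top-left offset (100, 100), 640x480 region
for paths, imgs, info in loader:                # each is a single-element list (bs == 1)
    print(info[0], imgs[0].shape)               # info string and a (480, 640, 3) BGR array
    break                                       # the loader captures indefinitely, so stop here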
It is suitable for use with\n`yolo predict source=screen`.\n\nAttributes:\n source (str): The source input indicating which screen to capture.\n screen (int): The screen number to capture.\n left (int): The left coordinate for screen capture area.\n top (int): The top coordinate for screen capture area.\n width (int): The width of the screen capture area.\n height (int): The height of the screen capture area.\n mode (str): Set to 'stream' indicating real-time capture.\n frame (int): Counter for captured frames.\n sct (mss.mss): Screen capture object from `mss` library.\n bs (int): Batch size, set to 1.\n fps (int): Frames per second, set to 30.\n monitor (Dict[str, int]): Monitor configuration details.\n cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).\n\nMethods:\n __iter__: Returns an iterator object.\n __next__: Captures the next screenshot and returns it.\n\nExamples:\n >>> loader = LoadScreenshots(\"0 100 100 640 480\") # screen 0, top-left (100,100), 640x480\n >>> for source, im, im0s, vid_cap, s in loader:\n ... print(f\"Captured frame: {im.shape}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "class_LoadScreenshots_a593e08f" }, { "content": "class LoadImagesAndVideos:\n \"\"\"\n A class for loading and processing images and videos for YOLO object detection.\n\n This class manages the loading and pre-processing of image and video data from various sources, including\n single image files, video files, and lists of image and video paths.\n\n Attributes:\n files (List[str]): List of image and video file paths.\n nf (int): Total number of files (images and videos).\n video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False).\n mode (str): Current mode, 'image' or 'video'.\n vid_stride (int): Stride for video frame-rate.\n bs (int): Batch size.\n cap (cv2.VideoCapture): Video capture object for OpenCV.\n frame (int): Frame counter for video.\n frames (int): Total number of frames in the video.\n count (int): Counter for iteration, initialized at 0 during __iter__().\n ni (int): Number of images.\n cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).\n\n Methods:\n __init__: Initialize the LoadImagesAndVideos object.\n __iter__: Returns an iterator object for VideoStream or ImageFolder.\n __next__: Returns the next batch of images or video frames along with their paths and metadata.\n _new_video: Creates a new video capture object for the given path.\n __len__: Returns the number of batches in the object.\n\n Examples:\n >>> loader = LoadImagesAndVideos(\"path/to/data\", batch=32, vid_stride=1)\n >>> for paths, imgs, info in loader:\n ... # Process batch of images or video frames\n ... 
pass\n\n Notes:\n - Supports various image formats including HEIC.\n - Handles both local files and directories.\n - Can read from a text file containing paths to images and videos.\n \"\"\"\n\n def __init__(self, path: Union[str, Path, List], batch: int = 1, vid_stride: int = 1, channels: int = 3):\n \"\"\"\n Initialize dataloader for images and videos, supporting various input formats.\n\n Args:\n path (str | Path | List): Path to images/videos, directory, or list of paths.\n batch (int): Batch size for processing.\n vid_stride (int): Video frame-rate stride.\n channels (int): Number of image channels (1 for grayscale, 3 for RGB).\n \"\"\"\n parent = None\n if isinstance(path, str) and Path(path).suffix == \".txt\": # *.txt file with img/vid/dir on each line\n parent = Path(path).parent\n path = Path(path).read_text().splitlines() # list of sources\n files = []\n for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:\n a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912\n if \"*\" in a:\n files.extend(sorted(glob.glob(a, recursive=True))) # glob\n elif os.path.isdir(a):\n files.extend(sorted(glob.glob(os.path.join(a, \"*.*\")))) # dir\n elif os.path.isfile(a):\n files.append(a) # files (absolute or relative to CWD)\n elif parent and (parent / p).is_file():\n files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)\n else:\n raise FileNotFoundError(f\"{p} does not exist\")\n\n # Define files as images or videos\n images, videos = [], []\n for f in files:\n suffix = f.rpartition(\".\")[-1].lower() # Get file extension without the dot and lowercase\n if suffix in IMG_FORMATS:\n images.append(f)\n elif suffix in VID_FORMATS:\n videos.append(f)\n ni, nv = len(images), len(videos)\n\n self.files = images + videos\n self.nf = ni + nv # number of files\n self.ni = ni # number of images\n self.video_flag = [False] * ni + [True] * nv\n self.mode = \"video\" if ni == 0 else \"image\" # default to video if no images\n self.vid_stride = vid_stride # video frame-rate stride\n self.bs = batch\n self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR # grayscale or RGB\n if any(videos):\n self._new_video(videos[0]) # new video\n else:\n self.cap = None\n if self.nf == 0:\n raise FileNotFoundError(f\"No images or videos found in {p}. 
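The constructor logic above accepts several source forms; a hedged sketch with placeholder paths (each call raises FileNotFoundError if its path does not exist):

from ultralytics.data.loaders import LoadImagesAndVideos

LoadImagesAndVideos("assets/bus.jpg")                         # single image file
LoadImagesAndVideos("assets")                                 # directory, expanded to assets/*.*
LoadImagesAndVideos("assets/*.mp4")                           # glob pattern
LoadImagesAndVideos("sources.txt")                            # *.txt file with one image/video path per line
LoadImagesAndVideos(["assets/bus.jpg", "assets/video.mp4"])   # explicit list of paths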
{FORMATS_HELP_MSG}\")\n\n def __iter__(self):\n \"\"\"Iterate through image/video files, yielding source paths, images, and metadata.\"\"\"\n self.count = 0\n return self\n\n def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:\n \"\"\"Return the next batch of images or video frames with their paths and metadata.\"\"\"\n paths, imgs, info = [], [], []\n while len(imgs) < self.bs:\n if self.count >= self.nf: # end of file list\n if imgs:\n return paths, imgs, info # return last partial batch\n else:\n raise StopIteration\n\n path = self.files[self.count]\n if self.video_flag[self.count]:\n self.mode = \"video\"\n if not self.cap or not self.cap.isOpened():\n self._new_video(path)\n\n success = False\n for _ in range(self.vid_stride):\n success = self.cap.grab()\n if not success:\n break # end of video or failure\n\n if success:\n success, im0 = self.cap.retrieve()\n im0 = (\n cv2.cvtColor(im0, cv2.COLOR_BGR2GRAY)[..., None]\n if self.cv2_flag == cv2.IMREAD_GRAYSCALE\n else im0\n )\n if success:\n self.frame += 1\n paths.append(path)\n imgs.append(im0)\n info.append(f\"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: \")\n if self.frame == self.frames: # end of video\n self.count += 1\n self.cap.release()\n else:\n # Move to the next file if the current video ended or failed to open\n self.count += 1\n if self.cap:\n self.cap.release()\n if self.count < self.nf:\n self._new_video(self.files[self.count])\n else:\n # Handle image files (including HEIC)\n self.mode = \"image\"\n if path.rpartition(\".\")[-1].lower() == \"heic\":\n # Load HEIC image using Pillow with pillow-heif\n check_requirements(\"pi-heif\")\n\n from pi_heif import register_heif_opener\n\n register_heif_opener() # Register HEIF opener with Pillow\n with Image.open(path) as img:\n im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) # convert image to BGR nparray\n else:\n im0 = imread(path, flags=self.cv2_flag) # BGR\n if im0 is None:\n LOGGER.warning(f\"Image Read Error {path}\")\n else:\n paths.append(path)\n imgs.append(im0)\n info.append(f\"image {self.count + 1}/{self.nf} {path}: \")\n self.count += 1 # move to the next file\n if self.count >= self.ni: # end of image list\n break\n\n return paths, imgs, info\n\n def _new_video(self, path: str):\n \"\"\"Create a new video capture object for the given path and initialize video-related attributes.\"\"\"\n self.frame = 0\n self.cap = cv2.VideoCapture(path)\n self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))\n if not self.cap.isOpened():\n raise FileNotFoundError(f\"Failed to open video {path}\")\n self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)\n\n def __len__(self) -> int:\n \"\"\"Return the number of files (images and videos) in the dataset.\"\"\"\n return math.ceil(self.nf / self.bs) # number of batches", "chunk_type": "class", "name": "LoadImagesAndVideos", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 307, "end_line": 486, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "A class for loading and processing images and videos for YOLO object detection.\n\nThis class manages the loading and pre-processing of image and video data from various sources, including\nsingle image files, video files, and lists of image and video paths.\n\nAttributes:\n files (List[str]): List of image and video file paths.\n nf (int): Total number of files (images and videos).\n video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False).\n mode 
(str): Current mode, 'image' or 'video'.\n vid_stride (int): Stride for video frame-rate.\n bs (int): Batch size.\n cap (cv2.VideoCapture): Video capture object for OpenCV.\n frame (int): Frame counter for video.\n frames (int): Total number of frames in the video.\n count (int): Counter for iteration, initialized at 0 during __iter__().\n ni (int): Number of images.\n cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).\n\nMethods:\n __init__: Initialize the LoadImagesAndVideos object.\n __iter__: Returns an iterator object for VideoStream or ImageFolder.\n __next__: Returns the next batch of images or video frames along with their paths and metadata.\n _new_video: Creates a new video capture object for the given path.\n __len__: Returns the number of batches in the object.\n\nExamples:\n >>> loader = LoadImagesAndVideos(\"path/to/data\", batch=32, vid_stride=1)\n >>> for paths, imgs, info in loader:\n ... # Process batch of images or video frames\n ... pass\n\nNotes:\n - Supports various image formats including HEIC.\n - Handles both local files and directories.\n - Can read from a text file containing paths to images and videos.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "class_LoadImagesAndVideos_0dea1fb6" }, { "content": "class LoadPilAndNumpy:\n \"\"\"\n Load images from PIL and Numpy arrays for batch processing.\n\n This class manages loading and pre-processing of image data from both PIL and Numpy formats. 
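A minimal iteration sketch for the LoadImagesAndVideos loader above, following its docstring example; "path/to/data" is a placeholder directory of images and/or videos:

from ultralytics.data.loaders import LoadImagesAndVideos

loader = LoadImagesAndVideos("path/to/data", batch=4, vid_stride=2)
print(len(loader))                      # number of batches, i.e. ceil(nf / bs)
for paths, imgs, info in loader:
    for p, im, s in zip(paths, imgs, info):
        print(s, im.shape)              # e.g. "image 1/8 path/to/data/img.jpg: " and (h, w, 3)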
It performs basic\n validation and format conversion to ensure that the images are in the required format for downstream processing.\n\n Attributes:\n paths (List[str]): List of image paths or autogenerated filenames.\n im0 (List[np.ndarray]): List of images stored as Numpy arrays.\n mode (str): Type of data being processed, set to 'image'.\n bs (int): Batch size, equivalent to the length of `im0`.\n\n Methods:\n _single_check: Validate and format a single image to a Numpy array.\n\n Examples:\n >>> from PIL import Image\n >>> import numpy as np\n >>> pil_img = Image.new(\"RGB\", (100, 100))\n >>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)\n >>> loader = LoadPilAndNumpy([pil_img, np_img])\n >>> paths, images, _ = next(iter(loader))\n >>> print(f\"Loaded {len(images)} images\")\n Loaded 2 images\n \"\"\"\n\n def __init__(self, im0: Union[Image.Image, np.ndarray, List], channels: int = 3):\n \"\"\"\n Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.\n\n Args:\n im0 (PIL.Image.Image | np.ndarray | List): Single image or list of images in PIL or numpy format.\n channels (int): Number of image channels (1 for grayscale, 3 for RGB).\n \"\"\"\n if not isinstance(im0, list):\n im0 = [im0]\n # use `image{i}.jpg` when Image.filename returns an empty path.\n self.paths = [getattr(im, \"filename\", \"\") or f\"image{i}.jpg\" for i, im in enumerate(im0)]\n pil_flag = \"L\" if channels == 1 else \"RGB\" # grayscale or RGB\n self.im0 = [self._single_check(im, pil_flag) for im in im0]\n self.mode = \"image\"\n self.bs = len(self.im0)\n\n @staticmethod\n def _single_check(im: Union[Image.Image, np.ndarray], flag: str = \"RGB\") -> np.ndarray:\n \"\"\"Validate and format an image to numpy array, ensuring RGB order and contiguous memory.\"\"\"\n assert isinstance(im, (Image.Image, np.ndarray)), f\"Expected PIL/np.ndarray image type, but got {type(im)}\"\n if isinstance(im, Image.Image):\n im = np.asarray(im.convert(flag))\n # adding new axis if it's grayscale, and converting to BGR if it's RGB\n im = im[..., None] if flag == \"L\" else im[..., ::-1]\n im = np.ascontiguousarray(im) # contiguous\n elif im.ndim == 2: # grayscale in numpy form\n im = im[..., None]\n return im\n\n def __len__(self) -> int:\n \"\"\"Return the length of the 'im0' attribute, representing the number of loaded images.\"\"\"\n return len(self.im0)\n\n def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:\n \"\"\"Return the next batch of images, paths, and metadata for processing.\"\"\"\n if self.count == 1: # loop only once as it's batch inference\n raise StopIteration\n self.count += 1\n return self.paths, self.im0, [\"\"] * self.bs\n\n def __iter__(self):\n \"\"\"Iterate through PIL/numpy images, yielding paths, raw images, and metadata for processing.\"\"\"\n self.count = 0\n return self", "chunk_type": "class", "name": "LoadPilAndNumpy", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 489, "end_line": 560, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Load images from PIL and Numpy arrays for batch processing.\n\nThis class manages loading and pre-processing of image data from both PIL and Numpy formats. 
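Extending the LoadPilAndNumpy docstring example above with the single-channel path exposed by its `channels` argument; the printed name is whatever the loader autogenerates for an in-memory image:

from PIL import Image

from ultralytics.data.loaders import LoadPilAndNumpy

pil_img = Image.new("RGB", (100, 100))
loader = LoadPilAndNumpy([pil_img], channels=1)   # converted to grayscale ("L") on load
paths, imgs, _ = next(iter(loader))
print(paths[0], imgs[0].shape)                    # e.g. "image0.jpg" and (100, 100, 1)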
It performs basic\nvalidation and format conversion to ensure that the images are in the required format for downstream processing.\n\nAttributes:\n paths (List[str]): List of image paths or autogenerated filenames.\n im0 (List[np.ndarray]): List of images stored as Numpy arrays.\n mode (str): Type of data being processed, set to 'image'.\n bs (int): Batch size, equivalent to the length of `im0`.\n\nMethods:\n _single_check: Validate and format a single image to a Numpy array.\n\nExamples:\n >>> from PIL import Image\n >>> import numpy as np\n >>> pil_img = Image.new(\"RGB\", (100, 100))\n >>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)\n >>> loader = LoadPilAndNumpy([pil_img, np_img])\n >>> paths, images, _ = next(iter(loader))\n >>> print(f\"Loaded {len(images)} images\")\n Loaded 2 images", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "class_LoadPilAndNumpy_7b049b8b" }, { "content": "class LoadTensor:\n \"\"\"\n A class for loading and processing tensor data for object detection tasks.\n\n This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for\n further processing in object detection pipelines.\n\n Attributes:\n im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W).\n bs (int): Batch size, inferred from the shape of `im0`.\n mode (str): Current processing mode, set to 'image'.\n paths (List[str]): List of image paths or auto-generated filenames.\n\n Methods:\n _single_check: Validates and formats an input tensor.\n\n Examples:\n >>> import torch\n >>> tensor = torch.rand(1, 3, 640, 640)\n >>> loader = LoadTensor(tensor)\n >>> paths, images, info = next(iter(loader))\n >>> print(f\"Processed {len(images)} images\")\n \"\"\"\n\n def __init__(self, im0: torch.Tensor) -> None:\n \"\"\"\n Initialize LoadTensor object for processing torch.Tensor image data.\n\n Args:\n im0 (torch.Tensor): Input tensor with shape (B, C, H, W).\n \"\"\"\n self.im0 = self._single_check(im0)\n self.bs = self.im0.shape[0]\n self.mode = \"image\"\n self.paths = [getattr(im, \"filename\", f\"image{i}.jpg\") for i, im in enumerate(im0)]\n\n @staticmethod\n def _single_check(im: torch.Tensor, stride: int = 32) -> torch.Tensor:\n \"\"\"Validate and format a single image tensor, ensuring correct shape and normalization.\"\"\"\n s = (\n f\"torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) \"\n f\"divisible by stride {stride}. 
Input shape{tuple(im.shape)} is incompatible.\"\n )\n if len(im.shape) != 4:\n if len(im.shape) != 3:\n raise ValueError(s)\n LOGGER.warning(s)\n im = im.unsqueeze(0)\n if im.shape[2] % stride or im.shape[3] % stride:\n raise ValueError(s)\n if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07\n LOGGER.warning(\n f\"torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. Dividing input by 255.\"\n )\n im = im.float() / 255.0\n\n return im\n\n def __iter__(self):\n \"\"\"Yield an iterator object for iterating through tensor image data.\"\"\"\n self.count = 0\n return self\n\n def __next__(self) -> Tuple[List[str], torch.Tensor, List[str]]:\n \"\"\"Yield the next batch of tensor images and metadata for processing.\"\"\"\n if self.count == 1:\n raise StopIteration\n self.count += 1\n return self.paths, self.im0, [\"\"] * self.bs\n\n def __len__(self) -> int:\n \"\"\"Return the batch size of the tensor input.\"\"\"\n return self.bs", "chunk_type": "class", "name": "LoadTensor", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 563, "end_line": 635, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "A class for loading and processing tensor data for object detection tasks.\n\nThis class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for\nfurther processing in object detection pipelines.\n\nAttributes:\n im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W).\n bs (int): Batch size, inferred from the shape of `im0`.\n mode (str): Current processing mode, set to 'image'.\n paths (List[str]): List of image paths or auto-generated filenames.\n\nMethods:\n _single_check: Validates and formats an input tensor.\n\nExamples:\n >>> import torch\n >>> tensor = torch.rand(1, 3, 640, 640)\n >>> loader = LoadTensor(tensor)\n >>> paths, images, info = next(iter(loader))\n >>> print(f\"Processed {len(images)} images\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "class_LoadTensor_0482392f" }, { "content": "def autocast_list(source: List[Any]) -> List[Union[Image.Image, np.ndarray]]:\n \"\"\"Merge a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction.\"\"\"\n files = []\n for im in source:\n if isinstance(im, (str, Path)): # filename or uri\n files.append(Image.open(urllib.request.urlopen(im) if str(im).startswith(\"http\") else im))\n elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image\n files.append(im)\n else:\n raise TypeError(\n f\"type {type(im).__name__} is not a supported Ultralytics prediction source type. 
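A sketch of the tensor path above: the input must be BCHW with height and width divisible by the default stride of 32 and values normalized to 0-1 (otherwise the loader warns and divides by 255):

import torch

from ultralytics.data.loaders import LoadTensor

batch = torch.rand(2, 3, 640, 640)        # BCHW, values in [0, 1), 640 divisible by 32
loader = LoadTensor(batch)
paths, im0, info = next(iter(loader))     # autogenerated paths, the tensor itself, empty info strings
print(len(loader), im0.shape)             # 2 torch.Size([2, 3, 640, 640])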
\\n\"\n f\"See https://docs.ultralytics.com/modes/predict for supported source types.\"\n )\n\n return files", "chunk_type": "function", "name": "autocast_list", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 638, "end_line": 652, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Merge a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction.", "parameters": [ "source: List[Any]" ], "return_type": "List[Union[Image.Image, np.ndarray]]", "decorators": [], "complexity_score": 4, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "function_autocast_list_3fd4e65d" }, { "content": "def get_best_youtube_url(url: str, method: str = \"pytube\") -> Optional[str]:\n \"\"\"\n Retrieve the URL of the best quality MP4 video stream from a given YouTube video.\n\n Args:\n url (str): The URL of the YouTube video.\n method (str): The method to use for extracting video info. Options are \"pytube\", \"pafy\", and \"yt-dlp\".\n\n Returns:\n (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.\n\n Examples:\n >>> url = \"https://www.youtube.com/watch?v=dQw4w9WgXcQ\"\n >>> best_url = get_best_youtube_url(url)\n >>> print(best_url)\n https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=...\n\n Notes:\n - Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp.\n - The function prioritizes streams with at least 1080p resolution when available.\n - For the \"yt-dlp\" method, it looks for formats with video codec, no audio, and *.mp4 extension.\n \"\"\"\n if method == \"pytube\":\n # Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954\n check_requirements(\"pytubefix>=6.5.2\")\n from pytubefix import YouTube\n\n streams = YouTube(url).streams.filter(file_extension=\"mp4\", only_video=True)\n streams = sorted(streams, key=lambda s: s.resolution, reverse=True) # sort streams by resolution\n for stream in streams:\n if stream.resolution and int(stream.resolution[:-1]) >= 1080: # check if resolution is at least 1080p\n return stream.url\n\n elif method == \"pafy\":\n check_requirements((\"pafy\", \"youtube_dl==2020.12.2\"))\n import pafy # noqa\n\n return pafy.new(url).getbestvideo(preftype=\"mp4\").url\n\n elif method == \"yt-dlp\":\n check_requirements(\"yt-dlp\")\n import yt_dlp\n\n with yt_dlp.YoutubeDL({\"quiet\": True}) as ydl:\n info_dict = ydl.extract_info(url, download=False) # extract info\n for f in reversed(info_dict.get(\"formats\", [])): # reversed because best is usually last\n # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size\n good_size = (f.get(\"width\") or 0) >= 1920 or (f.get(\"height\") or 0) >= 1080\n if good_size and f[\"vcodec\"] != \"none\" and f[\"acodec\"] == \"none\" and f[\"ext\"] == \"mp4\":\n return f.get(\"url\")", "chunk_type": "function", "name": "get_best_youtube_url", 
"file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 655, "end_line": 704, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Retrieve the URL of the best quality MP4 video stream from a given YouTube video.\n\nArgs:\n url (str): The URL of the YouTube video.\n method (str): The method to use for extracting video info. Options are \"pytube\", \"pafy\", and \"yt-dlp\".\n\nReturns:\n (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.\n\nExamples:\n >>> url = \"https://www.youtube.com/watch?v=dQw4w9WgXcQ\"\n >>> best_url = get_best_youtube_url(url)\n >>> print(best_url)\n https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=...\n\nNotes:\n - Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp.\n - The function prioritizes streams with at least 1080p resolution when available.\n - For the \"yt-dlp\" method, it looks for formats with video codec, no audio, and *.mp4 extension.", "parameters": [ "url: str", "method: str" ], "return_type": "Optional[str]", "decorators": [], "complexity_score": 8, "dependencies": [ "glob", "math", "os", "time", "urllib", "dataclasses.dataclass", "pathlib.Path", "threading.Thread", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "ultralytics.data.utils.FORMATS_HELP_MSG", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.VID_FORMATS", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.patches.imread", "mss", "pytubefix.YouTube", "pafy", "yt_dlp", "pi_heif.register_heif_opener" ], "chunk_id": "function_get_best_youtube_url_024e34e0" }, { "content": "LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)", "chunk_type": "variable", "name": "LOADERS", "file_path": "ultralytics\\ultralytics\\data\\loaders.py", "start_line": 708, "end_line": 708, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_LOADERS_0343eaec" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\data\\split.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_7445cd18" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\data\\split.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_c6fd2ac7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\split.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_b212a9e5" }, { "content": "from typing import Tuple, Union", "chunk_type": "import", "name": "Tuple, Union", "file_path": 
"ultralytics\\ultralytics\\data\\split.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Tuple, Union_3004f54a" }, { "content": "from ultralytics.data.utils import IMG_FORMATS, img2label_paths", "chunk_type": "import", "name": "IMG_FORMATS, img2label_paths", "file_path": "ultralytics\\ultralytics\\data\\split.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IMG_FORMATS, img2label_paths_d8a8f9df" }, { "content": "from ultralytics.utils import DATASETS_DIR, LOGGER, TQDM", "chunk_type": "import", "name": "DATASETS_DIR, LOGGER, TQDM", "file_path": "ultralytics\\ultralytics\\data\\split.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DATASETS_DIR, LOGGER, TQDM_b32b8d0b" }, { "content": "def split_classify_dataset(source_dir: Union[str, Path], train_ratio: float = 0.8) -> Path:\n \"\"\"\n Split classification dataset into train and val directories in a new directory.\n\n Creates a new directory '{source_dir}_split' with train/val subdirectories, preserving the original class\n structure with an 80/20 split by default.\n\n Directory structure:\n Before:\n caltech/\n ├── class1/\n │ ├── img1.jpg\n │ ├── img2.jpg\n │ └── ...\n ├── class2/\n │ ├── img1.jpg\n │ └── ...\n └── ...\n\n After:\n caltech_split/\n ├── train/\n │ ├── class1/\n │ │ ├── img1.jpg\n │ │ └── ...\n │ ├── class2/\n │ │ ├── img1.jpg\n │ │ └── ...\n │ └── ...\n └── val/\n ├── class1/\n │ ├── img2.jpg\n │ └── ...\n ├── class2/\n │ └── ...\n └── ...\n\n Args:\n source_dir (str | Path): Path to classification dataset root directory.\n train_ratio (float): Ratio for train split, between 0 and 1.\n\n Returns:\n (Path): Path to the created split directory.\n\n Examples:\n Split dataset with default 80/20 ratio\n >>> split_classify_dataset(\"path/to/caltech\")\n\n Split with custom ratio\n >>> split_classify_dataset(\"path/to/caltech\", 0.75)\n \"\"\"\n source_path = Path(source_dir)\n split_path = Path(f\"{source_path}_split\")\n train_path, val_path = split_path / \"train\", split_path / \"val\"\n\n # Create directory structure\n split_path.mkdir(exist_ok=True)\n train_path.mkdir(exist_ok=True)\n val_path.mkdir(exist_ok=True)\n\n # Process class directories\n class_dirs = [d for d in source_path.iterdir() if d.is_dir()]\n total_images = sum(len(list(d.glob(\"*.*\"))) for d in class_dirs)\n stats = f\"{len(class_dirs)} classes, {total_images} images\"\n LOGGER.info(f\"Splitting {source_path} ({stats}) into {train_ratio:.0%} train, {1 - train_ratio:.0%} val...\")\n\n for class_dir in class_dirs:\n # Create class directories\n (train_path / class_dir.name).mkdir(exist_ok=True)\n (val_path / class_dir.name).mkdir(exist_ok=True)\n\n # Split and copy files\n image_files = list(class_dir.glob(\"*.*\"))\n random.shuffle(image_files)\n split_idx = int(len(image_files) * train_ratio)\n\n for img in image_files[:split_idx]:\n shutil.copy2(img, train_path / class_dir.name / img.name)\n\n for img in image_files[split_idx:]:\n shutil.copy2(img, val_path / class_dir.name / img.name)\n\n 
LOGGER.info(f\"Split complete in {split_path} ✅\")\n return split_path", "chunk_type": "function", "name": "split_classify_dataset", "file_path": "ultralytics\\ultralytics\\data\\split.py", "start_line": 12, "end_line": 95, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Split classification dataset into train and val directories in a new directory.\n\nCreates a new directory '{source_dir}_split' with train/val subdirectories, preserving the original class\nstructure with an 80/20 split by default.\n\nDirectory structure:\n Before:\n caltech/\n ├── class1/\n │ ├── img1.jpg\n │ ├── img2.jpg\n │ └── ...\n ├── class2/\n │ ├── img1.jpg\n │ └── ...\n └── ...\n\n After:\n caltech_split/\n ├── train/\n │ ├── class1/\n │ │ ├── img1.jpg\n │ │ └── ...\n │ ├── class2/\n │ │ ├── img1.jpg\n │ │ └── ...\n │ └── ...\n └── val/\n ├── class1/\n │ ├── img2.jpg\n │ └── ...\n ├── class2/\n │ └── ...\n └── ...\n\nArgs:\n source_dir (str | Path): Path to classification dataset root directory.\n train_ratio (float): Ratio for train split, between 0 and 1.\n\nReturns:\n (Path): Path to the created split directory.\n\nExamples:\n Split dataset with default 80/20 ratio\n >>> split_classify_dataset(\"path/to/caltech\")\n\n Split with custom ratio\n >>> split_classify_dataset(\"path/to/caltech\", 0.75)", "parameters": [ "source_dir: Union[str, Path]", "train_ratio: float" ], "return_type": "Path", "decorators": [], "complexity_score": 6, "dependencies": [ "random", "shutil", "pathlib.Path", "typing.Tuple", "typing.Union", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM" ], "chunk_id": "function_split_classify_dataset_b767a475" }, { "content": "def autosplit(\n path: Path = DATASETS_DIR / \"coco8/images\",\n weights: Tuple[float, float, float] = (0.9, 0.1, 0.0),\n annotated_only: bool = False,\n) -> None:\n \"\"\"\n Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.\n\n Args:\n path (Path): Path to images directory.\n weights (tuple): Train, validation, and test split fractions.\n annotated_only (bool): If True, only images with an associated txt file are used.\n\n Examples:\n Split images with default weights\n >>> from ultralytics.data.split import autosplit\n >>> autosplit()\n\n Split with custom weights and annotated images only\n >>> autosplit(path=\"path/to/images\", weights=(0.8, 0.15, 0.05), annotated_only=True)\n \"\"\"\n path = Path(path) # images dir\n files = sorted(x for x in path.rglob(\"*.*\") if x.suffix[1:].lower() in IMG_FORMATS) # image files only\n n = len(files) # number of files\n random.seed(0) # for reproducibility\n indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split\n\n txt = [\"autosplit_train.txt\", \"autosplit_val.txt\", \"autosplit_test.txt\"] # 3 txt files\n for x in txt:\n if (path.parent / x).exists():\n (path.parent / x).unlink() # remove existing\n\n LOGGER.info(f\"Autosplitting images from {path}\" + \", using *.txt labeled images only\" * annotated_only)\n for i, img in TQDM(zip(indices, files), total=n):\n if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label\n with open(path.parent / txt[i], \"a\", encoding=\"utf-8\") as f:\n f.write(f\"./{img.relative_to(path.parent).as_posix()}\" + \"\\n\") # add image to txt file", "chunk_type": "function", "name": "autosplit", "file_path": 
"ultralytics\\ultralytics\\data\\split.py", "start_line": 98, "end_line": 134, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": "Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.\n\nArgs:\n path (Path): Path to images directory.\n weights (tuple): Train, validation, and test split fractions.\n annotated_only (bool): If True, only images with an associated txt file are used.\n\nExamples:\n Split images with default weights\n >>> from ultralytics.data.split import autosplit\n >>> autosplit()\n\n Split with custom weights and annotated images only\n >>> autosplit(path=\"path/to/images\", weights=(0.8, 0.15, 0.05), annotated_only=True)", "parameters": [ "path: Path", "weights: Tuple[float, float, float]", "annotated_only: bool" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "random", "shutil", "pathlib.Path", "typing.Tuple", "typing.Union", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM" ], "chunk_id": "function_autosplit_159af2a9" }, { "content": "import itertools", "chunk_type": "import", "name": "itertools", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_itertools_872ee474" }, { "content": "from glob import glob", "chunk_type": "import", "name": "glob", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_glob_398e8009" }, { "content": "from math import ceil", "chunk_type": "import", "name": "ceil", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ceil_58b1d94a" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_a3dad259" }, { "content": "from typing import Any, Dict, List, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Tuple", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple_dc3e5a5f" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_cv2_0c011e8a" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_8aeadfdd" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_32953568" }, { "content": "from ultralytics.data.utils import exif_size, img2label_paths", "chunk_type": "import", "name": "exif_size, img2label_paths", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_exif_size, img2label_paths_0854ddd7" }, { "content": "from ultralytics.utils import TQDM", "chunk_type": "import", "name": "TQDM", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TQDM_0a204744" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_ce3d5e7e" }, { "content": "def bbox_iof(polygon1: np.ndarray, bbox2: np.ndarray, eps: float = 1e-6) -> np.ndarray:\n \"\"\"\n Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.\n\n Args:\n polygon1 (np.ndarray): Polygon coordinates with shape (N, 8).\n bbox2 (np.ndarray): Bounding boxes with shape (N, 4).\n eps (float, optional): Small value to prevent division by zero.\n\n Returns:\n (np.ndarray): IoF scores with shape (N, 1) or (N, M) if bbox2 is (M, 4).\n\n Notes:\n Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].\n Bounding box format: [x_min, y_min, x_max, y_max].\n \"\"\"\n check_requirements(\"shapely>=2.0.0\")\n from shapely.geometry import Polygon\n\n polygon1 = polygon1.reshape(-1, 4, 2)\n lt_point = np.min(polygon1, axis=-2) # left-top\n rb_point = np.max(polygon1, axis=-2) # right-bottom\n bbox1 = np.concatenate([lt_point, rb_point], axis=-1)\n\n lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])\n rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])\n wh = np.clip(rb - lt, 0, np.inf)\n h_overlaps = wh[..., 0] * wh[..., 1]\n\n left, top, right, bottom = (bbox2[..., i] for i in range(4))\n polygon2 = np.stack([left, top, right, top, right, bottom, left, bottom], axis=-1).reshape(-1, 4, 2)\n\n sg_polys1 = [Polygon(p) for p in polygon1]\n sg_polys2 = [Polygon(p) for p in polygon2]\n overlaps = np.zeros(h_overlaps.shape)\n for p in 
zip(*np.nonzero(h_overlaps)):\n overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area\n unions = np.array([p.area for p in sg_polys1], dtype=np.float32)\n unions = unions[..., None]\n\n unions = np.clip(unions, eps, np.inf)\n outputs = overlaps / unions\n if outputs.ndim == 1:\n outputs = outputs[..., None]\n return outputs", "chunk_type": "function", "name": "bbox_iof", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 18, "end_line": 62, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.\n\nArgs:\n polygon1 (np.ndarray): Polygon coordinates with shape (N, 8).\n bbox2 (np.ndarray): Bounding boxes with shape (N, 4).\n eps (float, optional): Small value to prevent division by zero.\n\nReturns:\n (np.ndarray): IoF scores with shape (N, 1) or (N, M) if bbox2 is (M, 4).\n\nNotes:\n Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].\n Bounding box format: [x_min, y_min, x_max, y_max].", "parameters": [ "polygon1: np.ndarray", "bbox2: np.ndarray", "eps: float" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 7, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_bbox_iof_6c20e25b" }, { "content": "def load_yolo_dota(data_root: str, split: str = \"train\") -> List[Dict[str, Any]]:\n \"\"\"\n Load DOTA dataset annotations and image information.\n\n Args:\n data_root (str): Data root directory.\n split (str, optional): The split data set, could be 'train' or 'val'.\n\n Returns:\n (List[Dict[str, Any]]): List of annotation dictionaries containing image information.\n\n Notes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - train\n - val\n - labels\n - train\n - val\n \"\"\"\n assert split in {\"train\", \"val\"}, f\"Split must be 'train' or 'val', not {split}.\"\n im_dir = Path(data_root) / \"images\" / split\n assert im_dir.exists(), f\"Can't find {im_dir}, please check your data root.\"\n im_files = glob(str(Path(data_root) / \"images\" / split / \"*\"))\n lb_files = img2label_paths(im_files)\n annos = []\n for im_file, lb_file in zip(im_files, lb_files):\n w, h = exif_size(Image.open(im_file))\n with open(lb_file, encoding=\"utf-8\") as f:\n lb = [x.split() for x in f.read().strip().splitlines() if len(x)]\n lb = np.array(lb, dtype=np.float32)\n annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))\n return annos", "chunk_type": "function", "name": "load_yolo_dota", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 65, "end_line": 98, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Load DOTA dataset annotations and image information.\n\nArgs:\n data_root (str): Data root directory.\n split (str, optional): The split data set, could be 'train' or 'val'.\n\nReturns:\n (List[Dict[str, Any]]): List of annotation dictionaries containing image information.\n\nNotes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - train\n - val\n - labels\n - train\n - val", "parameters": [ "data_root: str", "split: str" ], "return_type": "List[Dict[str, Any]]", "decorators": [], "complexity_score": 3, "dependencies": [ "itertools", 
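Worked example of the Intersection-over-Foreground definition used by bbox_iof(): the intersection area is divided by the polygon (foreground) area rather than by the union, so a polygon half covered by a window scores 0.5. This sketch assumes shapely is installed, which the function itself requires (shapely>=2.0.0).

# IoF = intersection area / polygon area (not IoU).
import numpy as np
from shapely.geometry import Polygon, box

poly = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])  # polygon [0,0, 2,0, 2,2, 0,2], area 4
window = box(1, 0, 3, 2)                          # bbox [x_min, y_min, x_max, y_max] = [1, 0, 3, 2]
iof = poly.intersection(window).area / poly.area  # 2 / 4 = 0.5
print(np.isclose(iof, 0.5))                       # True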
"glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_load_yolo_dota_4eb22b7d" }, { "content": "def get_windows(\n im_size: Tuple[int, int],\n crop_sizes: Tuple[int, ...] = (1024,),\n gaps: Tuple[int, ...] = (200,),\n im_rate_thr: float = 0.6,\n eps: float = 0.01,\n) -> np.ndarray:\n \"\"\"\n Get the coordinates of sliding windows for image cropping.\n\n Args:\n im_size (Tuple[int, int]): Original image size, (H, W).\n crop_sizes (Tuple[int, ...], optional): Crop size of windows.\n gaps (Tuple[int, ...], optional): Gap between crops.\n im_rate_thr (float, optional): Threshold of windows areas divided by image areas.\n eps (float, optional): Epsilon value for math operations.\n\n Returns:\n (np.ndarray): Array of window coordinates with shape (N, 4) where each row is [x_start, y_start, x_stop, y_stop].\n \"\"\"\n h, w = im_size\n windows = []\n for crop_size, gap in zip(crop_sizes, gaps):\n assert crop_size > gap, f\"invalid crop_size gap pair [{crop_size} {gap}]\"\n step = crop_size - gap\n\n xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)\n xs = [step * i for i in range(xn)]\n if len(xs) > 1 and xs[-1] + crop_size > w:\n xs[-1] = w - crop_size\n\n yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)\n ys = [step * i for i in range(yn)]\n if len(ys) > 1 and ys[-1] + crop_size > h:\n ys[-1] = h - crop_size\n\n start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)\n stop = start + crop_size\n windows.append(np.concatenate([start, stop], axis=1))\n windows = np.concatenate(windows, axis=0)\n\n im_in_wins = windows.copy()\n im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)\n im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)\n im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])\n win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])\n im_rates = im_areas / win_areas\n if not (im_rates > im_rate_thr).any():\n max_rate = im_rates.max()\n im_rates[abs(im_rates - max_rate) < eps] = 1\n return windows[im_rates > im_rate_thr]", "chunk_type": "function", "name": "get_windows", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 101, "end_line": 151, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Get the coordinates of sliding windows for image cropping.\n\nArgs:\n im_size (Tuple[int, int]): Original image size, (H, W).\n crop_sizes (Tuple[int, ...], optional): Crop size of windows.\n gaps (Tuple[int, ...], optional): Gap between crops.\n im_rate_thr (float, optional): Threshold of windows areas divided by image areas.\n eps (float, optional): Epsilon value for math operations.\n\nReturns:\n (np.ndarray): Array of window coordinates with shape (N, 4) where each row is [x_start, y_start, x_stop, y_stop].", "parameters": [ "im_size: Tuple[int, int]", "crop_sizes: Tuple[int, ...]", "gaps: Tuple[int, ...]", "im_rate_thr: float", "eps: float" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 7, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", 
"ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_get_windows_d81e130c" }, { "content": "def get_window_obj(anno: Dict[str, Any], windows: np.ndarray, iof_thr: float = 0.7) -> List[np.ndarray]:\n \"\"\"Get objects for each window based on IoF threshold.\"\"\"\n h, w = anno[\"ori_size\"]\n label = anno[\"label\"]\n if len(label):\n label[:, 1::2] *= w\n label[:, 2::2] *= h\n iofs = bbox_iof(label[:, 1:], windows)\n # Unnormalized and misaligned coordinates\n return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] # window_anns\n else:\n return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] # window_anns", "chunk_type": "function", "name": "get_window_obj", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 154, "end_line": 165, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": "Get objects for each window based on IoF threshold.", "parameters": [ "anno: Dict[str, Any]", "windows: np.ndarray", "iof_thr: float" ], "return_type": "List[np.ndarray]", "decorators": [], "complexity_score": 4, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_get_window_obj_acd10dbd" }, { "content": "def crop_and_save(\n anno: Dict[str, Any],\n windows: np.ndarray,\n window_objs: List[np.ndarray],\n im_dir: str,\n lb_dir: str,\n allow_background_images: bool = True,\n) -> None:\n \"\"\"\n Crop images and save new labels for each window.\n\n Args:\n anno (Dict[str, Any]): Annotation dict, including 'filepath', 'label', 'ori_size' as its keys.\n windows (np.ndarray): Array of windows coordinates with shape (N, 4).\n window_objs (List[np.ndarray]): A list of labels inside each window.\n im_dir (str): The output directory path of images.\n lb_dir (str): The output directory path of labels.\n allow_background_images (bool, optional): Whether to include background images without labels.\n\n Notes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - train\n - val\n - labels\n - train\n - val\n \"\"\"\n im = cv2.imread(anno[\"filepath\"])\n name = Path(anno[\"filepath\"]).stem\n for i, window in enumerate(windows):\n x_start, y_start, x_stop, y_stop = window.tolist()\n new_name = f\"{name}__{x_stop - x_start}__{x_start}___{y_start}\"\n patch_im = im[y_start:y_stop, x_start:x_stop]\n ph, pw = patch_im.shape[:2]\n\n label = window_objs[i]\n if len(label) or allow_background_images:\n cv2.imwrite(str(Path(im_dir) / f\"{new_name}.jpg\"), patch_im)\n if len(label):\n label[:, 1::2] -= x_start\n label[:, 2::2] -= y_start\n label[:, 1::2] /= pw\n label[:, 2::2] /= ph\n\n with open(Path(lb_dir) / f\"{new_name}.txt\", \"w\", encoding=\"utf-8\") as f:\n for lb in label:\n formatted_coords = [f\"{coord:.6g}\" for coord in lb[1:]]\n f.write(f\"{int(lb[0])} {' '.join(formatted_coords)}\\n\")", "chunk_type": "function", "name": "crop_and_save", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 168, "end_line": 217, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": "Crop images and save new labels for each window.\n\nArgs:\n anno (Dict[str, Any]): Annotation dict, including 'filepath', 'label', 'ori_size' as its keys.\n windows 
(np.ndarray): Array of windows coordinates with shape (N, 4).\n window_objs (List[np.ndarray]): A list of labels inside each window.\n im_dir (str): The output directory path of images.\n lb_dir (str): The output directory path of labels.\n allow_background_images (bool, optional): Whether to include background images without labels.\n\nNotes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - train\n - val\n - labels\n - train\n - val", "parameters": [ "anno: Dict[str, Any]", "windows: np.ndarray", "window_objs: List[np.ndarray]", "im_dir: str", "lb_dir: str", "allow_background_images: bool" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_crop_and_save_d5e0007e" }, { "content": "def split_images_and_labels(\n data_root: str,\n save_dir: str,\n split: str = \"train\",\n crop_sizes: Tuple[int, ...] = (1024,),\n gaps: Tuple[int, ...] = (200,),\n) -> None:\n \"\"\"\n Split both images and labels for a given dataset split.\n\n Args:\n data_root (str): Root directory of the dataset.\n save_dir (str): Directory to save the split dataset.\n split (str, optional): The split data set, could be 'train' or 'val'.\n crop_sizes (Tuple[int, ...], optional): Tuple of crop sizes.\n gaps (Tuple[int, ...], optional): Tuple of gaps between crops.\n\n Notes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - split\n - labels\n - split\n and the output directory structure is:\n - save_dir\n - images\n - split\n - labels\n - split\n \"\"\"\n im_dir = Path(save_dir) / \"images\" / split\n im_dir.mkdir(parents=True, exist_ok=True)\n lb_dir = Path(save_dir) / \"labels\" / split\n lb_dir.mkdir(parents=True, exist_ok=True)\n\n annos = load_yolo_dota(data_root, split=split)\n for anno in TQDM(annos, total=len(annos), desc=split):\n windows = get_windows(anno[\"ori_size\"], crop_sizes, gaps)\n window_objs = get_window_obj(anno, windows)\n crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))", "chunk_type": "function", "name": "split_images_and_labels", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 220, "end_line": 260, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": "Split both images and labels for a given dataset split.\n\nArgs:\n data_root (str): Root directory of the dataset.\n save_dir (str): Directory to save the split dataset.\n split (str, optional): The split data set, could be 'train' or 'val'.\n crop_sizes (Tuple[int, ...], optional): Tuple of crop sizes.\n gaps (Tuple[int, ...], optional): Tuple of gaps between crops.\n\nNotes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - split\n - labels\n - split\n and the output directory structure is:\n - save_dir\n - images\n - split\n - labels\n - split", "parameters": [ "data_root: str", "save_dir: str", "split: str", "crop_sizes: Tuple[int, ...]", "gaps: Tuple[int, ...]" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", 
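Numeric illustration of the label renormalization done per window in crop_and_save(): after get_window_obj() the polygon coordinates are in absolute pixels, so each window subtracts its origin and divides by the patch size to return to 0-1 coordinates. This sketch derives the patch size from the window itself, whereas the library uses the actual cropped patch shape.

import numpy as np

window = np.array([476, 476, 1500, 1500])  # x_start, y_start, x_stop, y_stop
label = np.array([[3, 800.0, 900.0, 1000.0, 900.0, 1000.0, 1100.0, 800.0, 1100.0]])  # cls + 4 xy pairs, pixels
x_start, y_start, x_stop, y_stop = window.tolist()
pw, ph = x_stop - x_start, y_stop - y_start  # 1024 x 1024 patch
label[:, 1::2] -= x_start  # shift x coords into the window frame
label[:, 2::2] -= y_start  # shift y coords into the window frame
label[:, 1::2] /= pw       # renormalize x to 0-1
label[:, 2::2] /= ph       # renormalize y to 0-1
print(label.round(4))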
"ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_split_images_and_labels_d96d6400" }, { "content": "def split_trainval(\n data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: Tuple[float, ...] = (1.0,)\n) -> None:\n \"\"\"\n Split train and val sets of DOTA dataset with multiple scaling rates.\n\n Args:\n data_root (str): Root directory of the dataset.\n save_dir (str): Directory to save the split dataset.\n crop_size (int, optional): Base crop size.\n gap (int, optional): Base gap between crops.\n rates (Tuple[float, ...], optional): Scaling rates for crop_size and gap.\n\n Notes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - train\n - val\n - labels\n - train\n - val\n and the output directory structure is:\n - save_dir\n - images\n - train\n - val\n - labels\n - train\n - val\n \"\"\"\n crop_sizes, gaps = [], []\n for r in rates:\n crop_sizes.append(int(crop_size / r))\n gaps.append(int(gap / r))\n for split in {\"train\", \"val\"}:\n split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)", "chunk_type": "function", "name": "split_trainval", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 263, "end_line": 299, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": "Split train and val sets of DOTA dataset with multiple scaling rates.\n\nArgs:\n data_root (str): Root directory of the dataset.\n save_dir (str): Directory to save the split dataset.\n crop_size (int, optional): Base crop size.\n gap (int, optional): Base gap between crops.\n rates (Tuple[float, ...], optional): Scaling rates for crop_size and gap.\n\nNotes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - train\n - val\n - labels\n - train\n - val\n and the output directory structure is:\n - save_dir\n - images\n - train\n - val\n - labels\n - train\n - val", "parameters": [ "data_root: str", "save_dir: str", "crop_size: int", "gap: int", "rates: Tuple[float, ...]" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_split_trainval_a134a0e2" }, { "content": "def split_test(\n data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: Tuple[float, ...] 
= (1.0,)\n) -> None:\n \"\"\"\n Split test set of DOTA dataset, labels are not included within this set.\n\n Args:\n data_root (str): Root directory of the dataset.\n save_dir (str): Directory to save the split dataset.\n crop_size (int, optional): Base crop size.\n gap (int, optional): Base gap between crops.\n rates (Tuple[float, ...], optional): Scaling rates for crop_size and gap.\n\n Notes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - test\n and the output directory structure is:\n - save_dir\n - images\n - test\n \"\"\"\n crop_sizes, gaps = [], []\n for r in rates:\n crop_sizes.append(int(crop_size / r))\n gaps.append(int(gap / r))\n save_dir = Path(save_dir) / \"images\" / \"test\"\n save_dir.mkdir(parents=True, exist_ok=True)\n\n im_dir = Path(data_root) / \"images\" / \"test\"\n assert im_dir.exists(), f\"Can't find {im_dir}, please check your data root.\"\n im_files = glob(str(im_dir / \"*\"))\n for im_file in TQDM(im_files, total=len(im_files), desc=\"test\"):\n w, h = exif_size(Image.open(im_file))\n windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)\n im = cv2.imread(im_file)\n name = Path(im_file).stem\n for window in windows:\n x_start, y_start, x_stop, y_stop = window.tolist()\n new_name = f\"{name}__{x_stop - x_start}__{x_start}___{y_start}\"\n patch_im = im[y_start:y_stop, x_start:x_stop]\n cv2.imwrite(str(save_dir / f\"{new_name}.jpg\"), patch_im)", "chunk_type": "function", "name": "split_test", "file_path": "ultralytics\\ultralytics\\data\\split_dota.py", "start_line": 302, "end_line": 344, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": "Split test set of DOTA dataset, labels are not included within this set.\n\nArgs:\n data_root (str): Root directory of the dataset.\n save_dir (str): Directory to save the split dataset.\n crop_size (int, optional): Base crop size.\n gap (int, optional): Base gap between crops.\n rates (Tuple[float, ...], optional): Scaling rates for crop_size and gap.\n\nNotes:\n The directory structure assumed for the DOTA dataset:\n - data_root\n - images\n - test\n and the output directory structure is:\n - save_dir\n - images\n - test", "parameters": [ "data_root: str", "save_dir: str", "crop_size: int", "gap: int", "rates: Tuple[float, ...]" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "itertools", "glob.glob", "math.ceil", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "cv2", "numpy", "PIL.Image", "ultralytics.data.utils.exif_size", "ultralytics.data.utils.img2label_paths", "ultralytics.utils.TQDM", "ultralytics.utils.checks.check_requirements", "shapely.geometry.Polygon" ], "chunk_id": "function_split_test_53e094ed" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_e62151db" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_9111e510" }, { "content": "import random", "chunk_type": "import", 
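Companion usage sketch for split_test(): it tiles the unlabeled test images with the same windowing scheme, writing only image patches. Paths are placeholders.

from ultralytics.data.split_dota import split_test

split_test(data_root="DOTAv1", save_dir="DOTAv1-split", crop_size=1024, gap=200)  # placeholder paths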
"name": "random", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_74284ca7" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_40354b1c" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_6d6d25c9" }, { "content": "import zipfile", "chunk_type": "import", "name": "zipfile", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_zipfile_a9f6a7c2" }, { "content": "from multiprocessing.pool import ThreadPool", "chunk_type": "import", "name": "ThreadPool", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ThreadPool_57fd187a" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_c0fe3329" }, { "content": "from tarfile import is_tarfile", "chunk_type": "import", "name": "is_tarfile", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_is_tarfile_e99a9cbc" }, { "content": "from typing import Any, Dict, List, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Tuple, Union", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple, Union_2680e652" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_6beaf989" }, { "content": "import 
numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_fa68f5a8" }, { "content": "from PIL import Image, ImageOps", "chunk_type": "import", "name": "Image, ImageOps", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image, ImageOps_11916bba" }, { "content": "from ultralytics.nn.autobackend import check_class_names", "chunk_type": "import", "name": "check_class_names", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_class_names_2539e0c5" }, { "content": "from ultralytics.utils import (\n DATASETS_DIR,\n LOGGER,\n MACOS,\n NUM_THREADS,\n ROOT,\n SETTINGS_FILE,\n TQDM,\n YAML,\n clean_url,\n colorstr,\n emojis,\n is_dir_writeable,\n)", "chunk_type": "import", "name": "DATASETS_DIR, LOGGER, MACOS, NUM_THREADS, ROOT, SETTINGS_FILE, TQDM, YAML, clean_url, colorstr, emojis, is_dir_writeable", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 19, "end_line": 32, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DATASETS_DIR, LOGGER, MACOS, NUM_THREADS, ROOT, SETTINGS_FILE, TQDM, YAML, clean_url, colorstr, emojis, is_dir_writeable_92bf4d56" }, { "content": "from ultralytics.utils.checks import check_file, check_font, is_ascii", "chunk_type": "import", "name": "check_file, check_font, is_ascii", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 33, "end_line": 33, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_file, check_font, is_ascii_7ff39bf3" }, { "content": "from ultralytics.utils.downloads import download, safe_download, unzip_file", "chunk_type": "import", "name": "download, safe_download, unzip_file", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 34, "end_line": 34, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_download, safe_download, unzip_file_643a6f1d" }, { "content": "from ultralytics.utils.ops import segments2boxes", "chunk_type": "import", "name": "segments2boxes", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 35, "end_line": 35, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_segments2boxes_df54c1df" }, { "content": "HELP_URL = \"See https://docs.ultralytics.com/datasets for dataset formatting guidance.\"", 
"chunk_type": "variable", "name": "HELP_URL", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 87, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_HELP_URL_97ff058e" }, { "content": "IMG_FORMATS = {\"bmp\", \"dng\", \"jpeg\", \"jpg\", \"mpo\", \"png\", \"tif\", \"tiff\", \"webp\", \"pfm\", \"heic\"} # image suffixes", "chunk_type": "variable", "name": "IMG_FORMATS", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 38, "end_line": 38, "start_col": 0, "end_col": 95, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IMG_FORMATS_acf4fbb5" }, { "content": "VID_FORMATS = {\"asf\", \"avi\", \"gif\", \"m4v\", \"mkv\", \"mov\", \"mp4\", \"mpeg\", \"mpg\", \"ts\", \"wmv\", \"webm\"} # video suffixes", "chunk_type": "variable", "name": "VID_FORMATS", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 99, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_VID_FORMATS_95565efd" }, { "content": "PIN_MEMORY = str(os.getenv(\"PIN_MEMORY\", not MACOS)).lower() == \"true\" # global pin_memory for dataloaders", "chunk_type": "variable", "name": "PIN_MEMORY", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 40, "end_line": 40, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PIN_MEMORY_9d3bdcaf" }, { "content": "FORMATS_HELP_MSG = f\"Supported formats are:\\nimages: {IMG_FORMATS}\\nvideos: {VID_FORMATS}\"", "chunk_type": "variable", "name": "FORMATS_HELP_MSG", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 41, "end_line": 41, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_FORMATS_HELP_MSG_82b63698" }, { "content": "def img2label_paths(img_paths: List[str]) -> List[str]:\n \"\"\"Convert image paths to label paths by replacing 'images' with 'labels' and extension with '.txt'.\"\"\"\n sa, sb = f\"{os.sep}images{os.sep}\", f\"{os.sep}labels{os.sep}\" # /images/, /labels/ substrings\n return [sb.join(x.rsplit(sa, 1)).rsplit(\".\", 1)[0] + \".txt\" for x in img_paths]", "chunk_type": "function", "name": "img2label_paths", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 44, "end_line": 47, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": "Convert image paths to label paths by replacing 'images' with 'labels' and extension with '.txt'.", "parameters": [ "img_paths: List[str]" ], "return_type": "List[str]", "decorators": [], "complexity_score": 2, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", 
"ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_img2label_paths_cee5558d" }, { "content": "def check_file_speeds(\n files: List[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = \"\"\n):\n \"\"\"\n Check dataset file access speed and provide performance feedback.\n\n This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.\n It samples up to 5 files from the provided list and warns if access times exceed the threshold.\n\n Args:\n files (List[str]): List of file paths to check for access speed.\n threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.\n threshold_mb (float, optional): Threshold in megabytes per second for read speed warnings.\n max_files (int, optional): The maximum number of files to check.\n prefix (str, optional): Prefix string to add to log messages.\n\n Examples:\n >>> from pathlib import Path\n >>> image_files = list(Path(\"dataset/images\").glob(\"*.jpg\"))\n >>> check_file_speeds(image_files, threshold_ms=15)\n \"\"\"\n if not files or len(files) == 0:\n LOGGER.warning(f\"{prefix}Image speed checks: No files to check\")\n return\n\n # Sample files (max 5)\n files = random.sample(files, min(max_files, len(files)))\n\n # Test ping (stat time)\n ping_times = []\n file_sizes = []\n read_speeds = []\n\n for f in files:\n try:\n # Measure ping (stat call)\n start = time.perf_counter()\n file_size = os.stat(f).st_size\n ping_times.append((time.perf_counter() - start) * 1000) # ms\n file_sizes.append(file_size)\n\n # Measure read speed\n start = time.perf_counter()\n with open(f, \"rb\") as file_obj:\n _ = file_obj.read()\n read_time = time.perf_counter() - start\n if read_time > 0: # Avoid division by zero\n read_speeds.append(file_size / (1 << 20) / read_time) # MB/s\n except Exception:\n pass\n\n if not ping_times:\n LOGGER.warning(f\"{prefix}Image speed checks: failed to access files\")\n return\n\n # Calculate stats with uncertainties\n avg_ping = np.mean(ping_times)\n std_ping = np.std(ping_times, ddof=1) if len(ping_times) > 1 else 0\n size_msg = f\", size: {np.mean(file_sizes) / (1 << 10):.1f} KB\"\n ping_msg = f\"ping: {avg_ping:.1f}±{std_ping:.1f} ms\"\n\n if read_speeds:\n avg_speed = np.mean(read_speeds)\n std_speed = np.std(read_speeds, ddof=1) if len(read_speeds) > 1 else 0\n speed_msg = f\", read: {avg_speed:.1f}±{std_speed:.1f} MB/s\"\n else:\n speed_msg = \"\"\n\n if avg_ping < threshold_ms or avg_speed < threshold_mb:\n LOGGER.info(f\"{prefix}Fast image access ✅ ({ping_msg}{speed_msg}{size_msg})\")\n else:\n LOGGER.warning(\n f\"{prefix}Slow image access detected ({ping_msg}{speed_msg}{size_msg}). \"\n f\"Use local storage instead of remote/mounted storage for better performance. 
\"\n f\"See https://docs.ultralytics.com/guides/model-training-tips/\"\n )", "chunk_type": "function", "name": "check_file_speeds", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 50, "end_line": 125, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Check dataset file access speed and provide performance feedback.\n\nThis function tests the access speed of dataset files by measuring ping (stat call) time and read speed.\nIt samples up to 5 files from the provided list and warns if access times exceed the threshold.\n\nArgs:\n files (List[str]): List of file paths to check for access speed.\n threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.\n threshold_mb (float, optional): Threshold in megabytes per second for read speed warnings.\n max_files (int, optional): The maximum number of files to check.\n prefix (str, optional): Prefix string to add to log messages.\n\nExamples:\n >>> from pathlib import Path\n >>> image_files = list(Path(\"dataset/images\").glob(\"*.jpg\"))\n >>> check_file_speeds(image_files, threshold_ms=15)", "parameters": [ "files: List[str]", "threshold_ms: float", "threshold_mb: float", "max_files: int", "prefix: str" ], "return_type": null, "decorators": [], "complexity_score": 8, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_check_file_speeds_6460b8bc" }, { "content": "def get_hash(paths: List[str]) -> str:\n \"\"\"Return a single hash value of a list of paths (files or dirs).\"\"\"\n size = 0\n for p in paths:\n try:\n size += os.stat(p).st_size\n except OSError:\n continue\n h = __import__(\"hashlib\").sha256(str(size).encode()) # hash sizes\n h.update(\"\".join(paths).encode()) # hash paths\n return h.hexdigest() # return hash", "chunk_type": "function", "name": "get_hash", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 128, "end_line": 138, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Return a single hash value of a list of paths (files or dirs).", "parameters": [ "paths: List[str]" ], "return_type": "str", "decorators": [], "complexity_score": 3, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", 
"ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_get_hash_f8994d15" }, { "content": "def exif_size(img: Image.Image) -> Tuple[int, int]:\n \"\"\"Return exif-corrected PIL size.\"\"\"\n s = img.size # (width, height)\n if img.format == \"JPEG\": # only support JPEG images\n try:\n if exif := img.getexif():\n rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274\n if rotation in {6, 8}: # rotation 270 or 90\n s = s[1], s[0]\n except Exception:\n pass\n return s", "chunk_type": "function", "name": "exif_size", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 141, "end_line": 152, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Return exif-corrected PIL size.", "parameters": [ "img: Image.Image" ], "return_type": "Tuple[int, int]", "decorators": [], "complexity_score": 5, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_exif_size_6e4941ad" }, { "content": "def verify_image(args: Tuple) -> Tuple:\n \"\"\"Verify one image.\"\"\"\n (im_file, cls), prefix = args\n # Number (found, corrupt), message\n nf, nc, msg = 0, 0, \"\"\n try:\n im = Image.open(im_file)\n im.verify() # PIL verify\n shape = exif_size(im) # image size\n shape = (shape[1], shape[0]) # hw\n assert (shape[0] > 9) & (shape[1] > 9), f\"image size {shape} <10 pixels\"\n assert im.format.lower() in IMG_FORMATS, f\"Invalid image format {im.format}. 
{FORMATS_HELP_MSG}\"\n if im.format.lower() in {\"jpg\", \"jpeg\"}:\n with open(im_file, \"rb\") as f:\n f.seek(-2, 2)\n if f.read() != b\"\\xff\\xd9\": # corrupt JPEG\n ImageOps.exif_transpose(Image.open(im_file)).save(im_file, \"JPEG\", subsampling=0, quality=100)\n msg = f\"{prefix}{im_file}: corrupt JPEG restored and saved\"\n nf = 1\n except Exception as e:\n nc = 1\n msg = f\"{prefix}{im_file}: ignoring corrupt image/label: {e}\"\n return (im_file, cls), nf, nc, msg", "chunk_type": "function", "name": "verify_image", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 155, "end_line": 177, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Verify one image.", "parameters": [ "args: Tuple" ], "return_type": "Tuple", "decorators": [], "complexity_score": 4, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_verify_image_4c3b5e75" }, { "content": "def verify_image_label(args: Tuple) -> List:\n \"\"\"Verify one image-label pair.\"\"\"\n im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args\n # Number (missing, found, empty, corrupt), message, segments, keypoints\n nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, \"\", [], None\n try:\n # Verify images\n im = Image.open(im_file)\n im.verify() # PIL verify\n shape = exif_size(im) # image size\n shape = (shape[1], shape[0]) # hw\n assert (shape[0] > 9) & (shape[1] > 9), f\"image size {shape} <10 pixels\"\n assert im.format.lower() in IMG_FORMATS, f\"invalid image format {im.format}. 
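Hedged usage sketch for verify_image(): it takes ((image_path, class), prefix) and returns the (possibly repaired) sample together with found/corrupt counters and a message. The image path is a placeholder.

from ultralytics.data.utils import verify_image

(sample, cls), nf, nc, msg = verify_image((("datasets/coco8/images/train/000000000009.jpg", 0), ""))
print(nf, nc, msg or "ok")  # nf=1 for a valid image, nc=1 if it had to be skipped as corrupt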
{FORMATS_HELP_MSG}\"\n if im.format.lower() in {\"jpg\", \"jpeg\"}:\n with open(im_file, \"rb\") as f:\n f.seek(-2, 2)\n if f.read() != b\"\\xff\\xd9\": # corrupt JPEG\n ImageOps.exif_transpose(Image.open(im_file)).save(im_file, \"JPEG\", subsampling=0, quality=100)\n msg = f\"{prefix}{im_file}: corrupt JPEG restored and saved\"\n\n # Verify labels\n if os.path.isfile(lb_file):\n nf = 1 # label found\n with open(lb_file, encoding=\"utf-8\") as f:\n lb = [x.split() for x in f.read().strip().splitlines() if len(x)]\n if any(len(x) > 6 for x in lb) and (not keypoint): # is segment\n classes = np.array([x[0] for x in lb], dtype=np.float32)\n segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)\n lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)\n lb = np.array(lb, dtype=np.float32)\n if nl := len(lb):\n if keypoint:\n assert lb.shape[1] == (5 + nkpt * ndim), f\"labels require {(5 + nkpt * ndim)} columns each\"\n points = lb[:, 5:].reshape(-1, ndim)[:, :2]\n else:\n assert lb.shape[1] == 5, f\"labels require 5 columns, {lb.shape[1]} columns detected\"\n points = lb[:, 1:]\n # Coordinate points check with 1% tolerance\n assert points.max() <= 1.01, f\"non-normalized or out of bounds coordinates {points[points > 1.01]}\"\n assert lb.min() >= -0.01, f\"negative class labels {lb[lb < -0.01]}\"\n\n # All labels\n if single_cls:\n lb[:, 0] = 0\n max_cls = lb[:, 0].max() # max label count\n assert max_cls < num_cls, (\n f\"Label class {int(max_cls)} exceeds dataset class count {num_cls}. \"\n f\"Possible class labels are 0-{num_cls - 1}\"\n )\n _, i = np.unique(lb, axis=0, return_index=True)\n if len(i) < nl: # duplicate row check\n lb = lb[i] # remove duplicates\n if segments:\n segments = [segments[x] for x in i]\n msg = f\"{prefix}{im_file}: {nl - len(i)} duplicate labels removed\"\n else:\n ne = 1 # label empty\n lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)\n else:\n nm = 1 # label missing\n lb = np.zeros((0, (5 + nkpt * ndim) if keypoints else 5), dtype=np.float32)\n if keypoint:\n keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)\n if ndim == 2:\n kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)\n keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)\n lb = lb[:, :5]\n return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg\n except Exception as e:\n nc = 1\n msg = f\"{prefix}{im_file}: ignoring corrupt image/label: {e}\"\n return [None, None, None, None, None, nm, nf, ne, nc, msg]", "chunk_type": "function", "name": "verify_image_label", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 180, "end_line": 251, "start_col": 0, "end_col": 66, "parent_name": null, "docstring": "Verify one image-label pair.", "parameters": [ "args: Tuple" ], "return_type": "List", "decorators": [], "complexity_score": 18, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", 
"ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_verify_image_label_86beea4b" }, { "content": "def visualize_image_annotations(image_path: str, txt_path: str, label_map: Dict[int, str]):\n \"\"\"\n Visualize YOLO annotations (bounding boxes and class labels) on an image.\n\n This function reads an image and its corresponding annotation file in YOLO format, then\n draws bounding boxes around detected objects and labels them with their respective class names.\n The bounding box colors are assigned based on the class ID, and the text color is dynamically\n adjusted for readability, depending on the background color's luminance.\n\n Args:\n image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL.\n txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object.\n label_map (Dict[int, str]): A dictionary that maps class IDs (integers) to class labels (strings).\n\n Examples:\n >>> label_map = {0: \"cat\", 1: \"dog\", 2: \"bird\"} # It should include all annotated classes details\n >>> visualize_image_annotations(\"path/to/image.jpg\", \"path/to/annotations.txt\", label_map)\n \"\"\"\n import matplotlib.pyplot as plt\n\n from ultralytics.utils.plotting import colors\n\n img = np.array(Image.open(image_path))\n img_height, img_width = img.shape[:2]\n annotations = []\n with open(txt_path, encoding=\"utf-8\") as file:\n for line in file:\n class_id, x_center, y_center, width, height = map(float, line.split())\n x = (x_center - width / 2) * img_width\n y = (y_center - height / 2) * img_height\n w = width * img_width\n h = height * img_height\n annotations.append((x, y, w, h, int(class_id)))\n _, ax = plt.subplots(1) # Plot the image and annotations\n for x, y, w, h, label in annotations:\n color = tuple(c / 255 for c in colors(label, True)) # Get and normalize the RGB color\n rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor=\"none\") # Create a rectangle\n ax.add_patch(rect)\n luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2] # Formula for luminance\n ax.text(x, y - 5, label_map[label], color=\"white\" if luminance < 0.5 else \"black\", backgroundcolor=color)\n ax.imshow(img)\n plt.show()", "chunk_type": "function", "name": "visualize_image_annotations", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 254, "end_line": 295, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Visualize YOLO annotations (bounding boxes and class labels) on an image.\n\nThis function reads an image and its corresponding annotation file in YOLO format, then\ndraws bounding boxes around detected objects and labels them with their respective class names.\nThe bounding box colors are assigned based on the class ID, and the text color is dynamically\nadjusted for readability, depending on the background color's luminance.\n\nArgs:\n image_path (str): The path to the image file to annotate, and it can be 
in formats supported by PIL.\n txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object.\n label_map (Dict[int, str]): A dictionary that maps class IDs (integers) to class labels (strings).\n\nExamples:\n >>> label_map = {0: \"cat\", 1: \"dog\", 2: \"bird\"} # It should include all annotated classes details\n >>> visualize_image_annotations(\"path/to/image.jpg\", \"path/to/annotations.txt\", label_map)", "parameters": [ "image_path: str", "txt_path: str", "label_map: Dict[int, str]" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_visualize_image_annotations_b3e1beee" }, { "content": "def polygon2mask(\n imgsz: Tuple[int, int], polygons: List[np.ndarray], color: int = 1, downsample_ratio: int = 1\n) -> np.ndarray:\n \"\"\"\n Convert a list of polygons to a binary mask of the specified image size.\n\n Args:\n imgsz (Tuple[int, int]): The size of the image as (height, width).\n polygons (List[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where\n N is the number of polygons, and M is the number of points such that M % 2 = 0.\n color (int, optional): The color value to fill in the polygons on the mask.\n downsample_ratio (int, optional): Factor by which to downsample the mask.\n\n Returns:\n (np.ndarray): A binary mask of the specified image size with the polygons filled in.\n \"\"\"\n mask = np.zeros(imgsz, dtype=np.uint8)\n polygons = np.asarray(polygons, dtype=np.int32)\n polygons = polygons.reshape((polygons.shape[0], -1, 2))\n cv2.fillPoly(mask, polygons, color=color)\n nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)\n # Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1\n return cv2.resize(mask, (nw, nh))", "chunk_type": "function", "name": "polygon2mask", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 298, "end_line": 320, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Convert a list of polygons to a binary mask of the specified image size.\n\nArgs:\n imgsz (Tuple[int, int]): The size of the image as (height, width).\n polygons (List[np.ndarray]): A list of polygons. 
Each polygon is an array with shape (N, M), where\n N is the number of polygons, and M is the number of points such that M % 2 = 0.\n color (int, optional): The color value to fill in the polygons on the mask.\n downsample_ratio (int, optional): Factor by which to downsample the mask.\n\nReturns:\n (np.ndarray): A binary mask of the specified image size with the polygons filled in.", "parameters": [ "imgsz: Tuple[int, int]", "polygons: List[np.ndarray]", "color: int", "downsample_ratio: int" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 1, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_polygon2mask_7081fabf" }, { "content": "def polygons2masks(\n imgsz: Tuple[int, int], polygons: List[np.ndarray], color: int, downsample_ratio: int = 1\n) -> np.ndarray:\n \"\"\"\n Convert a list of polygons to a set of binary masks of the specified image size.\n\n Args:\n imgsz (Tuple[int, int]): The size of the image as (height, width).\n polygons (List[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where\n N is the number of polygons, and M is the number of points such that M % 2 = 0.\n color (int): The color value to fill in the polygons on the masks.\n downsample_ratio (int, optional): Factor by which to downsample each mask.\n\n Returns:\n (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.\n \"\"\"\n return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])", "chunk_type": "function", "name": "polygons2masks", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 323, "end_line": 339, "start_col": 0, "end_col": 102, "parent_name": null, "docstring": "Convert a list of polygons to a set of binary masks of the specified image size.\n\nArgs:\n imgsz (Tuple[int, int]): The size of the image as (height, width).\n polygons (List[np.ndarray]): A list of polygons. 
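A small sketch of `polygon2mask` under stated assumptions: the triangle coordinates below are made up, and the printed shapes follow from the `downsample_ratio` logic shown in the chunk (fill first, then resize).

```python
import numpy as np
from ultralytics.data.utils import polygon2mask

# One triangle given as a flattened [x1, y1, x2, y2, x3, y3] row in pixel coordinates.
triangle = np.array([[10, 10, 100, 10, 55, 90]], dtype=np.float32)

mask = polygon2mask(imgsz=(128, 128), polygons=triangle, color=1, downsample_ratio=2)
print(mask.shape, mask.max())  # (64, 64) 1 -- filled region carries the fill color
```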
Each polygon is an array with shape (N, M), where\n N is the number of polygons, and M is the number of points such that M % 2 = 0.\n color (int): The color value to fill in the polygons on the masks.\n downsample_ratio (int, optional): Factor by which to downsample each mask.\n\nReturns:\n (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.", "parameters": [ "imgsz: Tuple[int, int]", "polygons: List[np.ndarray]", "color: int", "downsample_ratio: int" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 2, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_polygons2masks_76f87607" }, { "content": "def polygons2masks_overlap(\n imgsz: Tuple[int, int], segments: List[np.ndarray], downsample_ratio: int = 1\n) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"Return a (640, 640) overlap mask.\"\"\"\n masks = np.zeros(\n (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),\n dtype=np.int32 if len(segments) > 255 else np.uint8,\n )\n areas = []\n ms = []\n for si in range(len(segments)):\n mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)\n ms.append(mask.astype(masks.dtype))\n areas.append(mask.sum())\n areas = np.asarray(areas)\n index = np.argsort(-areas)\n ms = np.array(ms)[index]\n for i in range(len(segments)):\n mask = ms[i] * (i + 1)\n masks = masks + mask\n masks = np.clip(masks, a_min=0, a_max=i + 1)\n return masks, index", "chunk_type": "function", "name": "polygons2masks_overlap", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 342, "end_line": 363, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Return a (640, 640) overlap mask.", "parameters": [ "imgsz: Tuple[int, int]", "segments: List[np.ndarray]", "downsample_ratio: int" ], "return_type": "Tuple[np.ndarray, np.ndarray]", "decorators": [], "complexity_score": 3, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", 
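`polygons2masks` simply maps `polygon2mask` over a list of polygons, so a brief sketch with two hypothetical polygons is enough to show the output layout (one binary mask per polygon).

```python
import numpy as np
from ultralytics.data.utils import polygons2masks

polys = [
    np.array([10, 10, 100, 10, 55, 90], dtype=np.float32),      # triangle (illustrative coordinates)
    np.array([5, 5, 60, 5, 60, 60, 5, 60], dtype=np.float32),   # rectangle
]
masks = polygons2masks((128, 128), polys, color=1)
print(masks.shape)  # (2, 128, 128) -- one mask per input polygon
```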
"ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_polygons2masks_overlap_5930dd6a" }, { "content": "def find_dataset_yaml(path: Path) -> Path:\n \"\"\"\n Find and return the YAML file associated with a Detect, Segment or Pose dataset.\n\n This function searches for a YAML file at the root level of the provided directory first, and if not found, it\n performs a recursive search. It prefers YAML files that have the same stem as the provided path.\n\n Args:\n path (Path): The directory path to search for the YAML file.\n\n Returns:\n (Path): The path of the found YAML file.\n \"\"\"\n files = list(path.glob(\"*.yaml\")) or list(path.rglob(\"*.yaml\")) # try root level first and then recursive\n assert files, f\"No YAML file found in '{path.resolve()}'\"\n if len(files) > 1:\n files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match\n assert len(files) == 1, f\"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\\n{files}\"\n return files[0]", "chunk_type": "function", "name": "find_dataset_yaml", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 366, "end_line": 384, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Find and return the YAML file associated with a Detect, Segment or Pose dataset.\n\nThis function searches for a YAML file at the root level of the provided directory first, and if not found, it\nperforms a recursive search. 
It prefers YAML files that have the same stem as the provided path.\n\nArgs:\n path (Path): The directory path to search for the YAML file.\n\nReturns:\n (Path): The path of the found YAML file.", "parameters": [ "path: Path" ], "return_type": "Path", "decorators": [], "complexity_score": 3, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_find_dataset_yaml_a0eb279c" }, { "content": "def check_det_dataset(dataset: str, autodownload: bool = True) -> Dict[str, Any]:\n \"\"\"\n Download, verify, and/or unzip a dataset if not found locally.\n\n This function checks the availability of a specified dataset, and if not found, it has the option to download and\n unzip the dataset. 
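A short sketch of `find_dataset_yaml`; the dataset directory below is a placeholder and only needs to contain a YAML file at its root (or somewhere below it) for the call to succeed.

```python
from pathlib import Path
from ultralytics.data.utils import find_dataset_yaml

# Placeholder path to an extracted dataset folder, e.g. the contents of coco8.zip.
yaml_path = find_dataset_yaml(Path("datasets/coco8"))
print(yaml_path)  # e.g. datasets/coco8/coco8.yaml (prefers a stem matching the folder name)
```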
It then reads and parses the accompanying YAML data, ensuring key requirements are met and also\n resolves paths related to the dataset.\n\n Args:\n dataset (str): Path to the dataset or dataset descriptor (like a YAML file).\n autodownload (bool, optional): Whether to automatically download the dataset if not found.\n\n Returns:\n (Dict[str, Any]): Parsed dataset information and paths.\n \"\"\"\n file = check_file(dataset)\n\n # Download (optional)\n extract_dir = \"\"\n if zipfile.is_zipfile(file) or is_tarfile(file):\n new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)\n file = find_dataset_yaml(DATASETS_DIR / new_dir)\n extract_dir, autodownload = file.parent, False\n\n # Read YAML\n data = YAML.load(file, append_filename=True) # dictionary\n\n # Checks\n for k in \"train\", \"val\":\n if k not in data:\n if k != \"val\" or \"validation\" not in data:\n raise SyntaxError(\n emojis(f\"{dataset} '{k}:' key missing ❌.\\n'train' and 'val' are required in all data YAMLs.\")\n )\n LOGGER.warning(\"renaming data YAML 'validation' key to 'val' to match YOLO format.\")\n data[\"val\"] = data.pop(\"validation\") # replace 'validation' key with 'val' key\n if \"names\" not in data and \"nc\" not in data:\n raise SyntaxError(emojis(f\"{dataset} key missing ❌.\\n either 'names' or 'nc' are required in all data YAMLs.\"))\n if \"names\" in data and \"nc\" in data and len(data[\"names\"]) != data[\"nc\"]:\n raise SyntaxError(emojis(f\"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match.\"))\n if \"names\" not in data:\n data[\"names\"] = [f\"class_{i}\" for i in range(data[\"nc\"])]\n else:\n data[\"nc\"] = len(data[\"names\"])\n\n data[\"names\"] = check_class_names(data[\"names\"])\n data[\"channels\"] = data.get(\"channels\", 3) # get image channels, default to 3\n\n # Resolve paths\n path = Path(extract_dir or data.get(\"path\") or Path(data.get(\"yaml_file\", \"\")).parent) # dataset root\n if not path.exists() and not path.is_absolute():\n path = (DATASETS_DIR / path).resolve() # path relative to DATASETS_DIR\n\n # Set paths\n data[\"path\"] = path # download scripts\n for k in \"train\", \"val\", \"test\", \"minival\":\n if data.get(k): # prepend path\n if isinstance(data[k], str):\n x = (path / data[k]).resolve()\n if not x.exists() and data[k].startswith(\"../\"):\n x = (path / data[k][3:]).resolve()\n data[k] = str(x)\n else:\n data[k] = [str((path / x).resolve()) for x in data[k]]\n\n # Parse YAML\n val, s = (data.get(x) for x in (\"val\", \"download\"))\n if val:\n val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path\n if not all(x.exists() for x in val):\n name = clean_url(dataset) # dataset name with URL auth stripped\n LOGGER.info(\"\")\n m = f\"Dataset '{name}' images not found, missing path '{[x for x in val if not x.exists()][0]}'\"\n if s and autodownload:\n LOGGER.warning(m)\n else:\n m += f\"\\nNote dataset download directory is '{DATASETS_DIR}'. 
You can update this in '{SETTINGS_FILE}'\"\n raise FileNotFoundError(m)\n t = time.time()\n r = None # success\n if s.startswith(\"http\") and s.endswith(\".zip\"): # URL\n safe_download(url=s, dir=DATASETS_DIR, delete=True)\n elif s.startswith(\"bash \"): # bash script\n LOGGER.info(f\"Running {s} ...\")\n r = os.system(s)\n else: # python script\n exec(s, {\"yaml\": data})\n dt = f\"({round(time.time() - t, 1)}s)\"\n s = f\"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}\" if r in {0, None} else f\"failure {dt} ❌\"\n LOGGER.info(f\"Dataset download {s}\\n\")\n check_font(\"Arial.ttf\" if is_ascii(data[\"names\"]) else \"Arial.Unicode.ttf\") # download fonts\n\n return data # dictionary", "chunk_type": "function", "name": "check_det_dataset", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 387, "end_line": 479, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Download, verify, and/or unzip a dataset if not found locally.\n\nThis function checks the availability of a specified dataset, and if not found, it has the option to download and\nunzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also\nresolves paths related to the dataset.\n\nArgs:\n dataset (str): Path to the dataset or dataset descriptor (like a YAML file).\n autodownload (bool, optional): Whether to automatically download the dataset if not found.\n\nReturns:\n (Dict[str, Any]): Parsed dataset information and paths.", "parameters": [ "dataset: str", "autodownload: bool" ], "return_type": "Dict[str, Any]", "decorators": [], "complexity_score": 24, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_check_det_dataset_1177c5cb" }, { "content": "def check_cls_dataset(dataset: Union[str, Path], split: str = \"\") -> Dict[str, Any]:\n \"\"\"\n Check a classification dataset such as Imagenet.\n\n This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.\n If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.\n\n Args:\n dataset (str | Path): The name of the dataset.\n split (str, optional): The split of the dataset. 
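Since `check_det_dataset` is the main entry point for detection datasets, a minimal sketch follows. It assumes the bundled `coco8.yaml` descriptor, which Ultralytics resolves and auto-downloads when the images are missing; the printed values reflect the standard COCO class list and the 3-channel default shown in the chunk.

```python
from ultralytics.data.utils import check_det_dataset

data = check_det_dataset("coco8.yaml", autodownload=True)
print(data["nc"], data["names"][0], data["channels"])  # 80 'person' 3
print(data["train"], data["val"])  # absolute image paths resolved against DATASETS_DIR
```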
Either 'val', 'test', or ''.\n\n Returns:\n (Dict[str, Any]): A dictionary containing the following keys:\n\n - 'train' (Path): The directory path containing the training set of the dataset.\n - 'val' (Path): The directory path containing the validation set of the dataset.\n - 'test' (Path): The directory path containing the test set of the dataset.\n - 'nc' (int): The number of classes in the dataset.\n - 'names' (Dict[int, str]): A dictionary of class names in the dataset.\n \"\"\"\n # Download (optional if dataset=https://file.zip is passed directly)\n if str(dataset).startswith((\"http:/\", \"https:/\")):\n dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)\n elif str(dataset).endswith((\".zip\", \".tar\", \".gz\")):\n file = check_file(dataset)\n dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)\n\n dataset = Path(dataset)\n data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()\n if not data_dir.is_dir():\n LOGGER.info(\"\")\n LOGGER.warning(f\"Dataset not found, missing path {data_dir}, attempting download...\")\n t = time.time()\n if str(dataset) == \"imagenet\":\n subprocess.run(f\"bash {ROOT / 'data/scripts/get_imagenet.sh'}\", shell=True, check=True)\n else:\n url = f\"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip\"\n download(url, dir=data_dir.parent)\n LOGGER.info(f\"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\\n\")\n train_set = data_dir / \"train\"\n if not train_set.is_dir():\n LOGGER.warning(f\"Dataset 'split=train' not found at {train_set}\")\n image_files = list(data_dir.rglob(\"*.jpg\")) + list(data_dir.rglob(\"*.png\"))\n if image_files:\n from ultralytics.data.split import split_classify_dataset\n\n LOGGER.info(f\"Found {len(image_files)} images in subdirectories. 
Attempting to split...\")\n data_dir = split_classify_dataset(data_dir, train_ratio=0.8)\n train_set = data_dir / \"train\"\n else:\n LOGGER.error(f\"No images found in {data_dir} or its subdirectories.\")\n val_set = (\n data_dir / \"val\"\n if (data_dir / \"val\").exists()\n else data_dir / \"validation\"\n if (data_dir / \"validation\").exists()\n else data_dir / \"valid\"\n if (data_dir / \"valid\").exists()\n else None\n ) # data/test or data/val\n test_set = data_dir / \"test\" if (data_dir / \"test\").exists() else None # data/val or data/test\n if split == \"val\" and not val_set:\n LOGGER.warning(\"Dataset 'split=val' not found, using 'split=test' instead.\")\n val_set = test_set\n elif split == \"test\" and not test_set:\n LOGGER.warning(\"Dataset 'split=test' not found, using 'split=val' instead.\")\n test_set = val_set\n\n nc = len([x for x in (data_dir / \"train\").glob(\"*\") if x.is_dir()]) # number of classes\n names = [x.name for x in (data_dir / \"train\").iterdir() if x.is_dir()] # class names list\n names = dict(enumerate(sorted(names)))\n\n # Print to console\n for k, v in {\"train\": train_set, \"val\": val_set, \"test\": test_set}.items():\n prefix = f\"{colorstr(f'{k}:')} {v}...\"\n if v is None:\n LOGGER.info(prefix)\n else:\n files = [path for path in v.rglob(\"*.*\") if path.suffix[1:].lower() in IMG_FORMATS]\n nf = len(files) # number of files\n nd = len({file.parent for file in files}) # number of directories\n if nf == 0:\n if k == \"train\":\n raise FileNotFoundError(f\"{dataset} '{k}:' no training images found\")\n else:\n LOGGER.warning(f\"{prefix} found {nf} images in {nd} classes (no images found)\")\n elif nd != nc:\n LOGGER.error(f\"{prefix} found {nf} images in {nd} classes (requires {nc} classes, not {nd})\")\n else:\n LOGGER.info(f\"{prefix} found {nf} images in {nd} classes ✅ \")\n\n return {\"train\": train_set, \"val\": val_set, \"test\": test_set, \"nc\": nc, \"names\": names, \"channels\": 3}", "chunk_type": "function", "name": "check_cls_dataset", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 482, "end_line": 573, "start_col": 0, "end_col": 106, "parent_name": null, "docstring": "Check a classification dataset such as Imagenet.\n\nThis function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.\nIf the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.\n\nArgs:\n dataset (str | Path): The name of the dataset.\n split (str, optional): The split of the dataset. 
Either 'val', 'test', or ''.\n\nReturns:\n (Dict[str, Any]): A dictionary containing the following keys:\n\n - 'train' (Path): The directory path containing the training set of the dataset.\n - 'val' (Path): The directory path containing the validation set of the dataset.\n - 'test' (Path): The directory path containing the test set of the dataset.\n - 'nc' (int): The number of classes in the dataset.\n - 'names' (Dict[int, str]): A dictionary of class names in the dataset.", "parameters": [ "dataset: Union[str, Path]", "split: str" ], "return_type": "Dict[str, Any]", "decorators": [], "complexity_score": 18, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_check_cls_dataset_72e0c14a" }, { "content": "class HUBDatasetStats:\n \"\"\"\n A class for generating HUB dataset JSON and `-hub` dataset directory.\n\n Args:\n path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).\n task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'.\n autodownload (bool): Attempt to download dataset if not found locally.\n\n Attributes:\n task (str): Dataset task type.\n hub_dir (Path): Directory path for HUB dataset files.\n im_dir (Path): Directory path for compressed images.\n stats (Dict): Statistics dictionary containing dataset information.\n data (Dict): Dataset configuration data.\n\n Methods:\n get_json: Return dataset JSON for Ultralytics HUB.\n process_images: Compress images for Ultralytics HUB.\n\n Note:\n Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets\n i.e. 
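A companion sketch for `check_cls_dataset`, assuming the small `imagenet10` sample is available from the Ultralytics assets release (the chunk shows it being fetched from `github.com/ultralytics/assets` when not found locally).

```python
from ultralytics.data.utils import check_cls_dataset

data = check_cls_dataset("imagenet10")
print(data["nc"], len(data["names"]))  # 10 10 -- one entry per class directory under train/
print(data["train"], data["val"])      # train/ and val/ (or validation//valid/) directories
```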
https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.\n\n Examples:\n >>> from ultralytics.data.utils import HUBDatasetStats\n >>> stats = HUBDatasetStats(\"path/to/coco8.zip\", task=\"detect\") # detect dataset\n >>> stats = HUBDatasetStats(\"path/to/coco8-seg.zip\", task=\"segment\") # segment dataset\n >>> stats = HUBDatasetStats(\"path/to/coco8-pose.zip\", task=\"pose\") # pose dataset\n >>> stats = HUBDatasetStats(\"path/to/dota8.zip\", task=\"obb\") # OBB dataset\n >>> stats = HUBDatasetStats(\"path/to/imagenet10.zip\", task=\"classify\") # classification dataset\n >>> stats.get_json(save=True)\n >>> stats.process_images()\n \"\"\"\n\n def __init__(self, path: str = \"coco8.yaml\", task: str = \"detect\", autodownload: bool = False):\n \"\"\"Initialize class.\"\"\"\n path = Path(path).resolve()\n LOGGER.info(f\"Starting HUB dataset checks for {path}....\")\n\n self.task = task # detect, segment, pose, classify, obb\n if self.task == \"classify\":\n unzip_dir = unzip_file(path)\n data = check_cls_dataset(unzip_dir)\n data[\"path\"] = unzip_dir\n else: # detect, segment, pose, obb\n _, data_dir, yaml_path = self._unzip(Path(path))\n try:\n # Load YAML with checks\n data = YAML.load(yaml_path)\n data[\"path\"] = \"\" # strip path since YAML should be in dataset root for all HUB datasets\n YAML.save(yaml_path, data)\n data = check_det_dataset(yaml_path, autodownload) # dict\n data[\"path\"] = data_dir # YAML path should be set to '' (relative) or parent (absolute)\n except Exception as e:\n raise Exception(\"error/HUB/dataset_stats/init\") from e\n\n self.hub_dir = Path(f\"{data['path']}-hub\")\n self.im_dir = self.hub_dir / \"images\"\n self.stats = {\"nc\": len(data[\"names\"]), \"names\": list(data[\"names\"].values())} # statistics dictionary\n self.data = data\n\n @staticmethod\n def _unzip(path: Path) -> Tuple[bool, str, Path]:\n \"\"\"Unzip data.zip.\"\"\"\n if not str(path).endswith(\".zip\"): # path is data.yaml\n return False, None, path\n unzip_dir = unzip_file(path, path=path.parent)\n assert unzip_dir.is_dir(), (\n f\"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/\"\n )\n return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path\n\n def _hub_ops(self, f: str):\n \"\"\"Save a compressed image for HUB previews.\"\"\"\n compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub\n\n def get_json(self, save: bool = False, verbose: bool = False) -> Dict:\n \"\"\"Return dataset JSON for Ultralytics HUB.\"\"\"\n\n def _round(labels):\n \"\"\"Update labels to integer class and 4 decimal place floats.\"\"\"\n if self.task == \"detect\":\n coordinates = labels[\"bboxes\"]\n elif self.task in {\"segment\", \"obb\"}: # Segment and OBB use segments. 
OBB segments are normalized xyxyxyxy\n coordinates = [x.flatten() for x in labels[\"segments\"]]\n elif self.task == \"pose\":\n n, nk, nd = labels[\"keypoints\"].shape\n coordinates = np.concatenate((labels[\"bboxes\"], labels[\"keypoints\"].reshape(n, nk * nd)), 1)\n else:\n raise ValueError(f\"Undefined dataset task={self.task}.\")\n zipped = zip(labels[\"cls\"], coordinates)\n return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]\n\n for split in \"train\", \"val\", \"test\":\n self.stats[split] = None # predefine\n path = self.data.get(split)\n\n # Check split\n if path is None: # no split\n continue\n files = [f for f in Path(path).rglob(\"*.*\") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split\n if not files: # no images\n continue\n\n # Get dataset statistics\n if self.task == \"classify\":\n from torchvision.datasets import ImageFolder # scope for faster 'import ultralytics'\n\n dataset = ImageFolder(self.data[split])\n\n x = np.zeros(len(dataset.classes)).astype(int)\n for im in dataset.imgs:\n x[im[1]] += 1\n\n self.stats[split] = {\n \"instance_stats\": {\"total\": len(dataset), \"per_class\": x.tolist()},\n \"image_stats\": {\"total\": len(dataset), \"unlabelled\": 0, \"per_class\": x.tolist()},\n \"labels\": [{Path(k).name: v} for k, v in dataset.imgs],\n }\n else:\n from ultralytics.data import YOLODataset\n\n dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)\n x = np.array(\n [\n np.bincount(label[\"cls\"].astype(int).flatten(), minlength=self.data[\"nc\"])\n for label in TQDM(dataset.labels, total=len(dataset), desc=\"Statistics\")\n ]\n ) # shape(128x80)\n self.stats[split] = {\n \"instance_stats\": {\"total\": int(x.sum()), \"per_class\": x.sum(0).tolist()},\n \"image_stats\": {\n \"total\": len(dataset),\n \"unlabelled\": int(np.all(x == 0, 1).sum()),\n \"per_class\": (x > 0).sum(0).tolist(),\n },\n \"labels\": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],\n }\n\n # Save, print and return\n if save:\n self.hub_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/\n stats_path = self.hub_dir / \"stats.json\"\n LOGGER.info(f\"Saving {stats_path.resolve()}...\")\n with open(stats_path, \"w\", encoding=\"utf-8\") as f:\n json.dump(self.stats, f) # save stats.json\n if verbose:\n LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))\n return self.stats\n\n def process_images(self) -> Path:\n \"\"\"Compress images for Ultralytics HUB.\"\"\"\n from ultralytics.data import YOLODataset # ClassificationDataset\n\n self.im_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/images/\n for split in \"train\", \"val\", \"test\":\n if self.data.get(split) is None:\n continue\n dataset = YOLODataset(img_path=self.data[split], data=self.data)\n with ThreadPool(NUM_THREADS) as pool:\n for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f\"{split} images\"):\n pass\n LOGGER.info(f\"Done. All images saved to {self.im_dir}\")\n return self.im_dir", "chunk_type": "class", "name": "HUBDatasetStats", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 576, "end_line": 740, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "A class for generating HUB dataset JSON and `-hub` dataset directory.\n\nArgs:\n path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).\n task (str): Dataset task. 
Options are 'detect', 'segment', 'pose', 'classify'.\n autodownload (bool): Attempt to download dataset if not found locally.\n\nAttributes:\n task (str): Dataset task type.\n hub_dir (Path): Directory path for HUB dataset files.\n im_dir (Path): Directory path for compressed images.\n stats (Dict): Statistics dictionary containing dataset information.\n data (Dict): Dataset configuration data.\n\nMethods:\n get_json: Return dataset JSON for Ultralytics HUB.\n process_images: Compress images for Ultralytics HUB.\n\nNote:\n Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets\n i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.\n\nExamples:\n >>> from ultralytics.data.utils import HUBDatasetStats\n >>> stats = HUBDatasetStats(\"path/to/coco8.zip\", task=\"detect\") # detect dataset\n >>> stats = HUBDatasetStats(\"path/to/coco8-seg.zip\", task=\"segment\") # segment dataset\n >>> stats = HUBDatasetStats(\"path/to/coco8-pose.zip\", task=\"pose\") # pose dataset\n >>> stats = HUBDatasetStats(\"path/to/dota8.zip\", task=\"obb\") # OBB dataset\n >>> stats = HUBDatasetStats(\"path/to/imagenet10.zip\", task=\"classify\") # classification dataset\n >>> stats.get_json(save=True)\n >>> stats.process_images()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "class_HUBDatasetStats_3d96d042" }, { "content": "def compress_one_image(f: str, f_new: str = None, max_dim: int = 1920, quality: int = 50):\n \"\"\"\n Compress a single image file to reduced size while preserving its aspect ratio and quality using either the Python\n Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be\n resized.\n\n Args:\n f (str): The path to the input image file.\n f_new (str, optional): The path to the output image file. 
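The `HUBDatasetStats` docstring already lists constructor examples; the sketch below just strings the two public methods together. The zip path is a placeholder and must contain its `data.yaml` at the dataset root, as the class docstring notes.

```python
from ultralytics.data.utils import HUBDatasetStats

stats = HUBDatasetStats("path/to/coco8.zip", task="detect")  # placeholder archive
summary = stats.get_json(save=True, verbose=False)  # writes <dataset>-hub/stats.json
images_dir = stats.process_images()                 # writes compressed previews to <dataset>-hub/images/
print(summary["nc"], images_dir)
```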
If not specified, the input file will be overwritten.\n max_dim (int, optional): The maximum dimension (width or height) of the output image.\n quality (int, optional): The image compression quality as a percentage.\n\n Examples:\n >>> from pathlib import Path\n >>> from ultralytics.data.utils import compress_one_image\n >>> for f in Path(\"path/to/dataset\").rglob(\"*.jpg\"):\n >>> compress_one_image(f)\n \"\"\"\n try: # use PIL\n Image.MAX_IMAGE_PIXELS = None # Fix DecompressionBombError, allow optimization of image > ~178.9 million pixels\n im = Image.open(f)\n if im.mode in {\"RGBA\", \"LA\"}: # Convert to RGB if needed (for JPEG)\n im = im.convert(\"RGB\")\n r = max_dim / max(im.height, im.width) # ratio\n if r < 1.0: # image too large\n im = im.resize((int(im.width * r), int(im.height * r)))\n im.save(f_new or f, \"JPEG\", quality=quality, optimize=True) # save\n except Exception as e: # use OpenCV\n LOGGER.warning(f\"HUB ops PIL failure {f}: {e}\")\n im = cv2.imread(f)\n im_height, im_width = im.shape[:2]\n r = max_dim / max(im_height, im_width) # ratio\n if r < 1.0: # image too large\n im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)\n cv2.imwrite(str(f_new or f), im)", "chunk_type": "function", "name": "compress_one_image", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 743, "end_line": 777, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Compress a single image file to reduced size while preserving its aspect ratio and quality using either the Python\nImaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be\nresized.\n\nArgs:\n f (str): The path to the input image file.\n f_new (str, optional): The path to the output image file. 
If not specified, the input file will be overwritten.\n max_dim (int, optional): The maximum dimension (width or height) of the output image.\n quality (int, optional): The image compression quality as a percentage.\n\nExamples:\n >>> from pathlib import Path\n >>> from ultralytics.data.utils import compress_one_image\n >>> for f in Path(\"path/to/dataset\").rglob(\"*.jpg\"):\n >>> compress_one_image(f)", "parameters": [ "f: str", "f_new: str", "max_dim: int", "quality: int" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_compress_one_image_85d678e5" }, { "content": "def load_dataset_cache_file(path: Path) -> Dict:\n \"\"\"Load an Ultralytics *.cache dictionary from path.\"\"\"\n import gc\n\n gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585\n cache = np.load(str(path), allow_pickle=True).item() # load dict\n gc.enable()\n return cache", "chunk_type": "function", "name": "load_dataset_cache_file", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 780, "end_line": 787, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Load an Ultralytics *.cache dictionary from path.", "parameters": [ "path: Path" ], "return_type": "Dict", "decorators": [], "complexity_score": 1, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", 
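A hedged sketch of `compress_one_image` showing the `f_new` argument, which the docstring example omits; the source path is a placeholder, and writing to a sibling file avoids overwriting the original.

```python
from pathlib import Path
from ultralytics.data.utils import compress_one_image

src = Path("path/to/image.jpg")  # placeholder
compress_one_image(str(src), f_new=str(src.with_name("image_small.jpg")), max_dim=1280, quality=60)
```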
"ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_load_dataset_cache_file_a1840500" }, { "content": "def save_dataset_cache_file(prefix: str, path: Path, x: Dict, version: str):\n \"\"\"Save an Ultralytics dataset *.cache dictionary x to path.\"\"\"\n x[\"version\"] = version # add cache version\n if is_dir_writeable(path.parent):\n if path.exists():\n path.unlink() # remove *.cache file if exists\n with open(str(path), \"wb\") as file: # context manager here fixes windows async np.save bug\n np.save(file, x)\n LOGGER.info(f\"{prefix}New cache created: {path}\")\n else:\n LOGGER.warning(f\"{prefix}Cache directory {path.parent} is not writeable, cache not saved.\")", "chunk_type": "function", "name": "save_dataset_cache_file", "file_path": "ultralytics\\ultralytics\\data\\utils.py", "start_line": 790, "end_line": 800, "start_col": 0, "end_col": 99, "parent_name": null, "docstring": "Save an Ultralytics dataset *.cache dictionary x to path.", "parameters": [ "prefix: str", "path: Path", "x: Dict", "version: str" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "json", "os", "random", "subprocess", "time", "zipfile", "multiprocessing.pool.ThreadPool", "pathlib.Path", "tarfile.is_tarfile", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "cv2", "numpy", "PIL.Image", "PIL.ImageOps", "ultralytics.nn.autobackend.check_class_names", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS_FILE", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.is_dir_writeable", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.downloads.download", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.downloads.unzip_file", "ultralytics.utils.ops.segments2boxes", "matplotlib.pyplot", "ultralytics.utils.plotting.colors", "gc", "ultralytics.data.YOLODataset", "ultralytics.data.split.split_classify_dataset", "torchvision.datasets.ImageFolder", "ultralytics.data.YOLODataset" ], "chunk_id": "function_save_dataset_cache_file_bc52abb0" }, { "content": "from .base import BaseDataset", "chunk_type": "import", "name": "BaseDataset", "file_path": "ultralytics\\ultralytics\\data\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseDataset_ee30aa3a" }, { "content": "from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source", "chunk_type": "import", "name": "build_dataloader, build_grounding, build_yolo_dataset, load_inference_source", "file_path": "ultralytics\\ultralytics\\data\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 95, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_build_dataloader, build_grounding, build_yolo_dataset, load_inference_source_f7c9c66b" }, { "content": "from .dataset import (\n ClassificationDataset,\n GroundingDataset,\n SemanticDataset,\n 
YOLOConcatDataset,\n YOLODataset,\n YOLOMultiModalDataset,\n)", "chunk_type": "import", "name": "ClassificationDataset, GroundingDataset, SemanticDataset, YOLOConcatDataset, YOLODataset, YOLOMultiModalDataset", "file_path": "ultralytics\\ultralytics\\data\\__init__.py", "start_line": 5, "end_line": 12, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationDataset, GroundingDataset, SemanticDataset, YOLOConcatDataset, YOLODataset, YOLOMultiModalDataset_58db4bf9" }, { "content": "__all__ = (\n \"BaseDataset\",\n \"ClassificationDataset\",\n \"SemanticDataset\",\n \"YOLODataset\",\n \"YOLOMultiModalDataset\",\n \"YOLOConcatDataset\",\n \"GroundingDataset\",\n \"build_yolo_dataset\",\n \"build_grounding\",\n \"build_dataloader\",\n \"load_inference_source\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\data\\__init__.py", "start_line": 14, "end_line": 26, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___327a94b6" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 58, "end_line": 58, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_be9f436b" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 59, "end_line": 59, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_5603b82e" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 60, "end_line": 60, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_c9557fe1" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 61, "end_line": 61, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_6b708d64" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 62, "end_line": 62, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_e40a3173" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 63, "end_line": 63, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, 
"chunk_id": "import_time_49f31c44" }, { "content": "import warnings", "chunk_type": "import", "name": "warnings", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 64, "end_line": 64, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_warnings_d6635558" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 65, "end_line": 65, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_41b96b46" }, { "content": "from datetime import datetime", "chunk_type": "import", "name": "datetime", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 66, "end_line": 66, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_datetime_21e29a74" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 67, "end_line": 67, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_2e3efd42" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 69, "end_line": 69, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_ddaf2c7c" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 70, "end_line": 70, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_6ad6debc" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 72, "end_line": 72, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___1728ed20" }, { "content": "from ultralytics.cfg import TASK2DATA, get_cfg", "chunk_type": "import", "name": "TASK2DATA, get_cfg", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 73, "end_line": 73, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, get_cfg_2e985fa9" }, { "content": "from ultralytics.data import build_dataloader", "chunk_type": "import", "name": "build_dataloader", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 74, "end_line": 74, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": 
null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_build_dataloader_d35ec4a7" }, { "content": "from ultralytics.data.dataset import YOLODataset", "chunk_type": "import", "name": "YOLODataset", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 75, "end_line": 75, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLODataset_7eb60a14" }, { "content": "from ultralytics.data.utils import check_cls_dataset, check_det_dataset", "chunk_type": "import", "name": "check_cls_dataset, check_det_dataset", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 76, "end_line": 76, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_cls_dataset, check_det_dataset_195ea01e" }, { "content": "from ultralytics.nn.autobackend import check_class_names, default_class_names", "chunk_type": "import", "name": "check_class_names, default_class_names", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 77, "end_line": 77, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_class_names, default_class_names_96d9ac3f" }, { "content": "from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder", "chunk_type": "import", "name": "C2f, Classify, Detect, RTDETRDecoder", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 78, "end_line": 78, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_C2f, Classify, Detect, RTDETRDecoder_46633000" }, { "content": "from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, WorldModel", "chunk_type": "import", "name": "ClassificationModel, DetectionModel, SegmentationModel, WorldModel", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 79, "end_line": 79, "start_col": 0, "end_col": 99, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationModel, DetectionModel, SegmentationModel, WorldModel_5fce3d67" }, { "content": "from ultralytics.utils import (\n ARM64,\n DEFAULT_CFG,\n IS_COLAB,\n IS_JETSON,\n LINUX,\n LOGGER,\n MACOS,\n MACOS_VERSION,\n RKNN_CHIPS,\n ROOT,\n SETTINGS,\n WINDOWS,\n YAML,\n callbacks,\n colorstr,\n get_default_args,\n)", "chunk_type": "import", "name": "ARM64, DEFAULT_CFG, IS_COLAB, IS_JETSON, LINUX, LOGGER, MACOS, MACOS_VERSION, RKNN_CHIPS, ROOT, SETTINGS, WINDOWS, YAML, callbacks, colorstr, get_default_args", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 80, "end_line": 97, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARM64, DEFAULT_CFG, IS_COLAB, IS_JETSON, LINUX, LOGGER, MACOS, MACOS_VERSION, RKNN_CHIPS, ROOT, 
SETTINGS, WINDOWS, YAML, callbacks, colorstr, get_default_args_90d4bf73" }, { "content": "from ultralytics.utils.checks import (\n check_imgsz,\n check_is_path_safe,\n check_requirements,\n check_version,\n is_intel,\n is_sudo_available,\n)", "chunk_type": "import", "name": "check_imgsz, check_is_path_safe, check_requirements, check_version, is_intel, is_sudo_available", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 98, "end_line": 105, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_imgsz, check_is_path_safe, check_requirements, check_version, is_intel, is_sudo_available_21963157" }, { "content": "from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download", "chunk_type": "import", "name": "attempt_download_asset, get_github_assets, safe_download", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 106, "end_line": 106, "start_col": 0, "end_col": 96, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_attempt_download_asset, get_github_assets, safe_download_6da5da63" }, { "content": "from ultralytics.utils.export import export_engine, export_onnx", "chunk_type": "import", "name": "export_engine, export_onnx", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 107, "end_line": 107, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_export_engine, export_onnx_c54fcfc8" }, { "content": "from ultralytics.utils.files import file_size, spaces_in_path", "chunk_type": "import", "name": "file_size, spaces_in_path", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 108, "end_line": 108, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_file_size, spaces_in_path_7ab59a8a" }, { "content": "from ultralytics.utils.ops import Profile, nms_rotated", "chunk_type": "import", "name": "Profile, nms_rotated", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 109, "end_line": 109, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Profile, nms_rotated_6efcf045" }, { "content": "from ultralytics.utils.patches import arange_patch", "chunk_type": "import", "name": "arange_patch", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 110, "end_line": 110, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_arange_patch_23a97ecd" }, { "content": "from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device", "chunk_type": "import", "name": "TORCH_1_13, get_latest_opset, select_device", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 111, "end_line": 111, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": 
null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_13, get_latest_opset, select_device_704bc770" }, { "content": "def export_formats():\n \"\"\"Return a dictionary of Ultralytics YOLO export formats.\"\"\"\n x = [\n [\"PyTorch\", \"-\", \".pt\", True, True, []],\n [\"TorchScript\", \"torchscript\", \".torchscript\", True, True, [\"batch\", \"optimize\", \"half\", \"nms\"]],\n [\"ONNX\", \"onnx\", \".onnx\", True, True, [\"batch\", \"dynamic\", \"half\", \"opset\", \"simplify\", \"nms\"]],\n [\n \"OpenVINO\",\n \"openvino\",\n \"_openvino_model\",\n True,\n False,\n [\"batch\", \"dynamic\", \"half\", \"int8\", \"nms\", \"fraction\"],\n ],\n [\n \"TensorRT\",\n \"engine\",\n \".engine\",\n False,\n True,\n [\"batch\", \"dynamic\", \"half\", \"int8\", \"simplify\", \"nms\", \"fraction\"],\n ],\n [\"CoreML\", \"coreml\", \".mlpackage\", True, False, [\"batch\", \"half\", \"int8\", \"nms\"]],\n [\"TensorFlow SavedModel\", \"saved_model\", \"_saved_model\", True, True, [\"batch\", \"int8\", \"keras\", \"nms\"]],\n [\"TensorFlow GraphDef\", \"pb\", \".pb\", True, True, [\"batch\"]],\n [\"TensorFlow Lite\", \"tflite\", \".tflite\", True, False, [\"batch\", \"half\", \"int8\", \"nms\", \"fraction\"]],\n [\"TensorFlow Edge TPU\", \"edgetpu\", \"_edgetpu.tflite\", True, False, []],\n [\"TensorFlow.js\", \"tfjs\", \"_web_model\", True, False, [\"batch\", \"half\", \"int8\", \"nms\"]],\n [\"PaddlePaddle\", \"paddle\", \"_paddle_model\", True, True, [\"batch\"]],\n [\"MNN\", \"mnn\", \".mnn\", True, True, [\"batch\", \"half\", \"int8\"]],\n [\"NCNN\", \"ncnn\", \"_ncnn_model\", True, True, [\"batch\", \"half\"]],\n [\"IMX\", \"imx\", \"_imx_model\", True, True, [\"int8\", \"fraction\", \"nms\"]],\n [\"RKNN\", \"rknn\", \"_rknn_model\", False, False, [\"batch\", \"name\"]],\n ]\n return dict(zip([\"Format\", \"Argument\", \"Suffix\", \"CPU\", \"GPU\", \"Arguments\"], zip(*x)))", "chunk_type": "function", "name": "export_formats", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 114, "end_line": 148, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": "Return a dictionary of Ultralytics YOLO export formats.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", 
"ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", "tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml" ], "chunk_id": "function_export_formats_909f2d97" }, { "content": "def validate_args(format, passed_args, valid_args):\n \"\"\"\n Validate arguments based on the export format.\n\n Args:\n format (str): The export format.\n passed_args (Namespace): The arguments used during export.\n valid_args (list): List of valid arguments for the format.\n\n Raises:\n AssertionError: If an unsupported argument is used, or if the format lacks supported argument listings.\n \"\"\"\n export_args = [\"half\", \"int8\", \"dynamic\", \"keras\", \"nms\", \"batch\", \"fraction\"]\n\n assert valid_args is not None, f\"ERROR ❌️ valid arguments for '{format}' not listed.\"\n custom = {\"batch\": 1, \"data\": None, \"device\": None} # exporter defaults\n default_args = get_cfg(DEFAULT_CFG, custom)\n for arg in export_args:\n not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None)\n if not_default:\n assert arg in valid_args, f\"ERROR ❌️ argument '{arg}' is not supported for format='{format}'\"", "chunk_type": "function", "name": "validate_args", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 151, "end_line": 171, "start_col": 0, "end_col": 109, "parent_name": null, "docstring": "Validate arguments based on the export format.\n\nArgs:\n format (str): The export format.\n passed_args (Namespace): The arguments used during export.\n valid_args (list): List of valid arguments for the format.\n\nRaises:\n AssertionError: If an unsupported argument is used, or if the format lacks supported argument listings.", "parameters": [ "format", "passed_args", "valid_args" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", 
"ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", "tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml" ], "chunk_id": "function_validate_args_cbf366e5" }, { "content": "def gd_outputs(gd):\n \"\"\"Return TensorFlow GraphDef model output node names.\"\"\"\n name_list, input_list = [], []\n for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef\n name_list.append(node.name)\n input_list.extend(node.input)\n return sorted(f\"{x}:0\" for x in list(set(name_list) - set(input_list)) if not x.startswith(\"NoOp\"))", "chunk_type": "function", "name": "gd_outputs", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 174, "end_line": 180, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "Return TensorFlow GraphDef model output node names.", "parameters": [ "gd" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", 
"datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", "tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml" ], "chunk_id": "function_gd_outputs_554738e3" }, { "content": "def try_export(inner_func):\n \"\"\"YOLO export decorator, i.e. 
@try_export.\"\"\"\n inner_args = get_default_args(inner_func)\n\n def outer_func(*args, **kwargs):\n \"\"\"Export a model.\"\"\"\n prefix = inner_args[\"prefix\"]\n dt = 0.0\n try:\n with Profile() as dt:\n f, model = inner_func(*args, **kwargs)\n LOGGER.info(f\"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)\")\n return f, model\n except Exception as e:\n LOGGER.error(f\"{prefix} export failure {dt.t:.1f}s: {e}\")\n raise e\n\n return outer_func", "chunk_type": "function", "name": "try_export", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 183, "end_line": 200, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "YOLO export decorator, i.e. @try_export.", "parameters": [ "inner_func" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", 
"tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml" ], "chunk_id": "function_try_export_25d36689" }, { "content": "class Exporter:\n \"\"\"\n A class for exporting YOLO models to various formats.\n\n This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,\n TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export\n process for each supported format.\n\n Attributes:\n args (SimpleNamespace): Configuration arguments for the exporter.\n callbacks (dict): Dictionary of callback functions for different export events.\n im (torch.Tensor): Input tensor for model inference during export.\n model (torch.nn.Module): The YOLO model to be exported.\n file (Path): Path to the model file being exported.\n output_shape (tuple): Shape of the model output tensor(s).\n pretty_name (str): Formatted model name for display purposes.\n metadata (dict): Model metadata including description, author, version, etc.\n device (torch.device): Device on which the model is loaded.\n imgsz (tuple): Input image size for the model.\n\n Methods:\n __call__: Main export method that handles the export process.\n get_int8_calibration_dataloader: Build dataloader for INT8 calibration.\n export_torchscript: Export model to TorchScript format.\n export_onnx: Export model to ONNX format.\n export_openvino: Export model to OpenVINO format.\n export_paddle: Export model to PaddlePaddle format.\n export_mnn: Export model to MNN format.\n export_ncnn: Export model to NCNN format.\n export_coreml: Export model to CoreML format.\n export_engine: Export model to TensorRT format.\n export_saved_model: Export model to TensorFlow SavedModel format.\n export_pb: Export model to TensorFlow GraphDef format.\n export_tflite: Export model to TensorFlow Lite format.\n export_edgetpu: Export model to Edge TPU format.\n export_tfjs: Export model to TensorFlow.js format.\n export_rknn: Export model to RKNN format.\n export_imx: Export model to IMX format.\n\n Examples:\n Export a YOLOv8 model to ONNX format\n >>> from ultralytics.engine.exporter import Exporter\n >>> exporter = Exporter()\n >>> exporter(model=\"yolov8n.pt\") # exports to yolov8n.onnx\n\n Export with specific arguments\n >>> args = {\"format\": \"onnx\", \"dynamic\": True, \"half\": True}\n >>> exporter = Exporter(overrides=args)\n >>> exporter(model=\"yolov8n.pt\")\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the Exporter class.\n\n Args:\n cfg (str, optional): Path to a configuration file.\n overrides (dict, optional): Configuration overrides.\n _callbacks (dict, optional): Dictionary of callback functions.\n \"\"\"\n self.args = get_cfg(cfg, overrides)\n self.callbacks = _callbacks or callbacks.get_default_callbacks()\n callbacks.add_integration_callbacks(self)\n\n def __call__(self, model=None) -> str:\n \"\"\"Return list of exported files/dirs after running callbacks.\"\"\"\n t = time.time()\n fmt = self.args.format.lower() # to lowercase\n if fmt in {\"tensorrt\", \"trt\"}: # 'engine' aliases\n fmt = \"engine\"\n if fmt in {\"mlmodel\", \"mlpackage\", \"mlprogram\", \"apple\", \"ios\", \"coreml\"}: # 'coreml' aliases\n fmt = \"coreml\"\n fmts_dict = export_formats()\n fmts = tuple(fmts_dict[\"Argument\"][1:]) # available export formats\n if fmt not in fmts:\n import difflib\n\n # Get the closest match if 
format is invalid\n matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6) # 60% similarity required to match\n if not matches:\n raise ValueError(f\"Invalid export format='{fmt}'. Valid formats are {fmts}\")\n LOGGER.warning(f\"Invalid export format='{fmt}', updating to format='{matches[0]}'\")\n fmt = matches[0]\n flags = [x == fmt for x in fmts]\n if sum(flags) != 1:\n raise ValueError(f\"Invalid export format='{fmt}'. Valid formats are {fmts}\")\n (jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn, imx, rknn) = (\n flags # export booleans\n )\n\n is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))\n\n # Device\n dla = None\n if engine and self.args.device is None:\n LOGGER.warning(\"TensorRT requires GPU export, automatically assigning device=0\")\n self.args.device = \"0\"\n if engine and \"dla\" in str(self.args.device): # convert int/list to str first\n dla = self.args.device.rsplit(\":\", 1)[-1]\n self.args.device = \"0\" # update device to \"0\"\n assert dla in {\"0\", \"1\"}, f\"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}.\"\n if imx and self.args.device is None and torch.cuda.is_available():\n LOGGER.warning(\"Exporting on CPU while CUDA is available, setting device=0 for faster export on GPU.\")\n self.args.device = \"0\" # update device to \"0\"\n self.device = select_device(\"cpu\" if self.args.device is None else self.args.device)\n\n # Argument compatibility checks\n fmt_keys = fmts_dict[\"Arguments\"][flags.index(True) + 1]\n validate_args(fmt, self.args, fmt_keys)\n if imx:\n if not self.args.int8:\n LOGGER.warning(\"IMX export requires int8=True, setting int8=True.\")\n self.args.int8 = True\n if not self.args.nms:\n LOGGER.warning(\"IMX export requires nms=True, setting nms=True.\")\n self.args.nms = True\n if model.task not in {\"detect\", \"pose\"}:\n raise ValueError(\"IMX export only supported for detection and pose estimation models.\")\n if not hasattr(model, \"names\"):\n model.names = default_class_names()\n model.names = check_class_names(model.names)\n if self.args.half and self.args.int8:\n LOGGER.warning(\"half=True and int8=True are mutually exclusive, setting half=False.\")\n self.args.half = False\n if self.args.half and onnx and self.device.type == \"cpu\":\n LOGGER.warning(\"half=True only compatible with GPU export, i.e. use device=0\")\n self.args.half = False\n self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size\n if self.args.optimize:\n assert not ncnn, \"optimize=True not compatible with format='ncnn', i.e. use optimize=False\"\n assert self.device.type == \"cpu\", \"optimize=True not compatible with cuda devices, i.e. use device='cpu'\"\n if rknn:\n if not self.args.name:\n LOGGER.warning(\n \"Rockchip RKNN export requires a missing 'name' arg for processor type. \"\n \"Using default name='rk3588'.\"\n )\n self.args.name = \"rk3588\"\n self.args.name = self.args.name.lower()\n assert self.args.name in RKNN_CHIPS, (\n f\"Invalid processor name '{self.args.name}' for Rockchip RKNN export. 
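The fuzzy format correction in __call__ above is plain difflib; a self-contained sketch:

import difflib

fmts = ("torchscript", "onnx", "openvino", "engine", "coreml", "saved_model", "tflite")
print(difflib.get_close_matches("onxx", fmts, n=1, cutoff=0.6))    # ['onnx'] -> warned about and corrected
print(difflib.get_close_matches("foobar", fmts, n=1, cutoff=0.6))  # []       -> __call__ raises ValueError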
Valid names are {RKNN_CHIPS}.\"\n )\n if self.args.int8 and tflite:\n assert not getattr(model, \"end2end\", False), \"TFLite INT8 export not supported for end2end models.\"\n if self.args.nms:\n assert not isinstance(model, ClassificationModel), \"'nms=True' is not valid for classification models.\"\n assert not (tflite and ARM64 and LINUX), \"TFLite export with NMS unsupported on ARM64 Linux\"\n if getattr(model, \"end2end\", False):\n LOGGER.warning(\"'nms=True' is not available for end2end models. Forcing 'nms=False'.\")\n self.args.nms = False\n self.args.conf = self.args.conf or 0.25 # set conf default value for nms export\n if (engine or self.args.nms) and self.args.dynamic and self.args.batch == 1:\n LOGGER.warning(\n f\"'dynamic=True' model with '{'nms=True' if self.args.nms else 'format=engine'}' requires max batch size, i.e. 'batch=16'\"\n )\n if edgetpu:\n if not LINUX or ARM64:\n raise SystemError(\n \"Edge TPU export only supported on non-aarch64 Linux. See https://coral.ai/docs/edgetpu/compiler\"\n )\n elif self.args.batch != 1: # see github.com/ultralytics/ultralytics/pull/13420\n LOGGER.warning(\"Edge TPU export requires batch size 1, setting batch=1.\")\n self.args.batch = 1\n if isinstance(model, WorldModel):\n LOGGER.warning(\n \"YOLOWorld (original version) export is not supported to any format. \"\n \"YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to \"\n \"(torchscript, onnx, openvino, engine, coreml) formats. \"\n \"See https://docs.ultralytics.com/models/yolo-world for details.\"\n )\n model.clip_model = None # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445\n if self.args.int8 and not self.args.data:\n self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, \"task\", \"detect\")] # assign default data\n LOGGER.warning(\n f\"INT8 export requires a missing 'data' arg for calibration. 
Using default 'data={self.args.data}'.\"\n )\n if tfjs and (ARM64 and LINUX):\n raise SystemError(\"TF.js exports are not currently supported on ARM64 Linux\")\n # Recommend OpenVINO if export and Intel CPU\n if SETTINGS.get(\"openvino_msg\"):\n if is_intel():\n LOGGER.info(\n \"💡 ProTip: Export to OpenVINO format for best performance on Intel hardware.\"\n \" Learn more at https://docs.ultralytics.com/integrations/openvino/\"\n )\n SETTINGS[\"openvino_msg\"] = False\n\n # Input\n im = torch.zeros(self.args.batch, model.yaml.get(\"channels\", 3), *self.imgsz).to(self.device)\n file = Path(\n getattr(model, \"pt_path\", None) or getattr(model, \"yaml_file\", None) or model.yaml.get(\"yaml_file\", \"\")\n )\n if file.suffix in {\".yaml\", \".yml\"}:\n file = Path(file.name)\n\n # Update model\n model = deepcopy(model).to(self.device)\n for p in model.parameters():\n p.requires_grad = False\n model.eval()\n model.float()\n model = model.fuse()\n\n if imx:\n from ultralytics.utils.torch_utils import FXModel\n\n model = FXModel(model)\n for m in model.modules():\n if isinstance(m, Classify):\n m.export = True\n if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB\n m.dynamic = self.args.dynamic\n m.export = True\n m.format = self.args.format\n m.max_det = self.args.max_det\n m.xyxy = self.args.nms and not coreml\n elif isinstance(m, C2f) and not is_tf_format:\n # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph\n m.forward = m.forward_split\n if isinstance(m, Detect) and imx:\n from ultralytics.utils.tal import make_anchors\n\n m.anchors, m.strides = (\n x.transpose(0, 1)\n for x in make_anchors(\n torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5\n )\n )\n\n y = None\n for _ in range(2): # dry runs\n y = NMSModel(model, self.args)(im) if self.args.nms and not (coreml or imx) else model(im)\n if self.args.half and onnx and self.device.type != \"cpu\":\n im, model = im.half(), model.half() # to FP16\n\n # Filter warnings\n warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning) # suppress TracerWarning\n warnings.filterwarnings(\"ignore\", category=UserWarning) # suppress shape prim::Constant missing ONNX warning\n warnings.filterwarnings(\"ignore\", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning\n\n # Assign\n self.im = im\n self.model = model\n self.file = file\n self.output_shape = (\n tuple(y.shape)\n if isinstance(y, torch.Tensor)\n else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)\n )\n self.pretty_name = Path(self.model.yaml.get(\"yaml_file\", self.file)).stem.replace(\"yolo\", \"YOLO\")\n data = model.args[\"data\"] if hasattr(model, \"args\") and isinstance(model.args, dict) else \"\"\n description = f\"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}\"\n self.metadata = {\n \"description\": description,\n \"author\": \"Ultralytics\",\n \"date\": datetime.now().isoformat(),\n \"version\": __version__,\n \"license\": \"AGPL-3.0 License (https://ultralytics.com/license)\",\n \"docs\": \"https://docs.ultralytics.com\",\n \"stride\": int(max(model.stride)),\n \"task\": model.task,\n \"batch\": self.args.batch,\n \"imgsz\": self.imgsz,\n \"names\": model.names,\n \"args\": {k: v for k, v in self.args if k in fmt_keys},\n \"channels\": model.yaml.get(\"channels\", 3),\n } # model metadata\n if dla is not None:\n self.metadata[\"dla\"] = dla # make sure `AutoBackend` uses correct dla device if 
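The dry-run and output-shape bookkeeping above reduces to a small torch-only pattern; a sketch with a stand-in module in place of the fused YOLO model:

import torch

model = torch.nn.Conv2d(3, 8, 3, padding=1).eval()  # stand-in for the fused model
im = torch.zeros(1, 3, 640, 640)                    # batch x channels x imgsz, as built above
y = None
for _ in range(2):                                  # dry runs
    y = model(im)
output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
print(output_shape)  # (1, 8, 640, 640)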
it has one\n if model.task == \"pose\":\n self.metadata[\"kpt_shape\"] = model.model[-1].kpt_shape\n\n LOGGER.info(\n f\"\\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and \"\n f\"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)\"\n )\n self.run_callbacks(\"on_export_start\")\n # Exports\n f = [\"\"] * len(fmts) # exported filenames\n if jit or ncnn: # TorchScript\n f[0], _ = self.export_torchscript()\n if engine: # TensorRT required before ONNX\n f[1], _ = self.export_engine(dla=dla)\n if onnx: # ONNX\n f[2], _ = self.export_onnx()\n if xml: # OpenVINO\n f[3], _ = self.export_openvino()\n if coreml: # CoreML\n f[4], _ = self.export_coreml()\n if is_tf_format: # TensorFlow formats\n self.args.int8 |= edgetpu\n f[5], keras_model = self.export_saved_model()\n if pb or tfjs: # pb prerequisite to tfjs\n f[6], _ = self.export_pb(keras_model=keras_model)\n if tflite:\n f[7], _ = self.export_tflite()\n if edgetpu:\n f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f\"{self.file.stem}_full_integer_quant.tflite\")\n if tfjs:\n f[9], _ = self.export_tfjs()\n if paddle: # PaddlePaddle\n f[10], _ = self.export_paddle()\n if mnn: # MNN\n f[11], _ = self.export_mnn()\n if ncnn: # NCNN\n f[12], _ = self.export_ncnn()\n if imx:\n f[13], _ = self.export_imx()\n if rknn:\n f[14], _ = self.export_rknn()\n\n # Finish\n f = [str(x) for x in f if x] # filter out '' and None\n if any(f):\n f = str(Path(f[-1]))\n square = self.imgsz[0] == self.imgsz[1]\n s = (\n \"\"\n if square\n else f\"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not \"\n f\"work. Use export 'imgsz={max(self.imgsz)}' if val is required.\"\n )\n imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(\" \", \"\")\n predict_data = f\"data={data}\" if model.task == \"segment\" and pb else \"\"\n q = \"int8\" if self.args.int8 else \"half\" if self.args.half else \"\" # quantization\n LOGGER.info(\n f\"\\nExport complete ({time.time() - t:.1f}s)\"\n f\"\\nResults saved to {colorstr('bold', file.parent.resolve())}\"\n f\"\\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}\"\n f\"\\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}\"\n f\"\\nVisualize: https://netron.app\"\n )\n\n self.run_callbacks(\"on_export_end\")\n return f # return list of exported files/dirs\n\n def get_int8_calibration_dataloader(self, prefix=\"\"):\n \"\"\"Build and return a dataloader for calibration of INT8 models.\"\"\"\n LOGGER.info(f\"{prefix} collecting INT8 calibration images from 'data={self.args.data}'\")\n data = (check_cls_dataset if self.model.task == \"classify\" else check_det_dataset)(self.args.data)\n dataset = YOLODataset(\n data[self.args.split or \"val\"],\n data=data,\n fraction=self.args.fraction,\n task=self.model.task,\n imgsz=self.imgsz[0],\n augment=False,\n batch_size=self.args.batch,\n )\n n = len(dataset)\n if n < self.args.batch:\n raise ValueError(\n f\"The calibration dataset ({n} images) must have at least as many images as the batch size \"\n f\"('batch={self.args.batch}').\"\n )\n elif n < 300:\n LOGGER.warning(f\"{prefix} >300 images recommended for INT8 calibration, found {n} images.\")\n return build_dataloader(dataset, batch=self.args.batch, workers=0, drop_last=True) # required for batch loading\n\n @try_export\n def export_torchscript(self, prefix=colorstr(\"TorchScript:\")):\n \"\"\"Export YOLO model to TorchScript format.\"\"\"\n LOGGER.info(f\"\\n{prefix} starting 
export with torch {torch.__version__}...\")\n f = self.file.with_suffix(\".torchscript\")\n\n ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)\n extra_files = {\"config.txt\": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()\n if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html\n LOGGER.info(f\"{prefix} optimizing for mobile...\")\n from torch.utils.mobile_optimizer import optimize_for_mobile\n\n optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)\n else:\n ts.save(str(f), _extra_files=extra_files)\n return f, None\n\n @try_export\n def export_onnx(self, prefix=colorstr(\"ONNX:\")):\n \"\"\"Export YOLO model to ONNX format.\"\"\"\n requirements = [\"onnx>=1.12.0,<1.18.0\"]\n if self.args.simplify:\n requirements += [\"onnxslim>=0.1.59\", \"onnxruntime\" + (\"-gpu\" if torch.cuda.is_available() else \"\")]\n check_requirements(requirements)\n import onnx # noqa\n\n opset_version = self.args.opset or get_latest_opset()\n LOGGER.info(f\"\\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...\")\n f = str(self.file.with_suffix(\".onnx\"))\n output_names = [\"output0\", \"output1\"] if isinstance(self.model, SegmentationModel) else [\"output0\"]\n dynamic = self.args.dynamic\n if dynamic:\n dynamic = {\"images\": {0: \"batch\", 2: \"height\", 3: \"width\"}} # shape(1,3,640,640)\n if isinstance(self.model, SegmentationModel):\n dynamic[\"output0\"] = {0: \"batch\", 2: \"anchors\"} # shape(1, 116, 8400)\n dynamic[\"output1\"] = {0: \"batch\", 2: \"mask_height\", 3: \"mask_width\"} # shape(1,32,160,160)\n elif isinstance(self.model, DetectionModel):\n dynamic[\"output0\"] = {0: \"batch\", 2: \"anchors\"} # shape(1, 84, 8400)\n if self.args.nms: # only batch size is dynamic with NMS\n dynamic[\"output0\"].pop(2)\n if self.args.nms and self.model.task == \"obb\":\n self.args.opset = opset_version # for NMSModel\n\n with arange_patch(self.args):\n export_onnx(\n NMSModel(self.model, self.args) if self.args.nms else self.model,\n self.im,\n f,\n opset=opset_version,\n input_names=[\"images\"],\n output_names=output_names,\n dynamic=dynamic or None,\n )\n\n # Checks\n model_onnx = onnx.load(f) # load onnx model\n\n # Simplify\n if self.args.simplify:\n try:\n import onnxslim\n\n LOGGER.info(f\"{prefix} slimming with onnxslim {onnxslim.__version__}...\")\n model_onnx = onnxslim.slim(model_onnx)\n\n except Exception as e:\n LOGGER.warning(f\"{prefix} simplifier failure: {e}\")\n\n # Metadata\n for k, v in self.metadata.items():\n meta = model_onnx.metadata_props.add()\n meta.key, meta.value = k, str(v)\n\n onnx.save(model_onnx, f)\n return f, model_onnx\n\n @try_export\n def export_openvino(self, prefix=colorstr(\"OpenVINO:\")):\n \"\"\"Export YOLO model to OpenVINO format.\"\"\"\n # OpenVINO <= 2025.1.0 error on macOS 15.4+: https://github.com/openvinotoolkit/openvino/issues/30023\"\n check_requirements(\"openvino>=2025.2.0\" if MACOS and MACOS_VERSION >= \"15.4\" else \"openvino>=2024.0.0\")\n import openvino as ov\n\n LOGGER.info(f\"\\n{prefix} starting export with openvino {ov.__version__}...\")\n assert TORCH_1_13, f\"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed\"\n ov_model = ov.convert_model(\n NMSModel(self.model, self.args) if self.args.nms else self.model,\n input=None if self.args.dynamic else [self.im.shape],\n example_input=self.im,\n )\n\n def serialize(ov_model, file):\n \"\"\"Set RT info, 
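export_onnx() above writes self.metadata into the model's metadata_props; the entries can be read back with the onnx package (the path is illustrative):

import onnx

model_onnx = onnx.load("yolo11n.onnx")
meta = {p.key: p.value for p in model_onnx.metadata_props}
print(meta.get("stride"), meta.get("imgsz"), meta.get("names"))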
serialize, and save metadata YAML.\"\"\"\n ov_model.set_rt_info(\"YOLO\", [\"model_info\", \"model_type\"])\n ov_model.set_rt_info(True, [\"model_info\", \"reverse_input_channels\"])\n ov_model.set_rt_info(114, [\"model_info\", \"pad_value\"])\n ov_model.set_rt_info([255.0], [\"model_info\", \"scale_values\"])\n ov_model.set_rt_info(self.args.iou, [\"model_info\", \"iou_threshold\"])\n ov_model.set_rt_info([v.replace(\" \", \"_\") for v in self.model.names.values()], [\"model_info\", \"labels\"])\n if self.model.task != \"classify\":\n ov_model.set_rt_info(\"fit_to_window_letterbox\", [\"model_info\", \"resize_type\"])\n\n ov.save_model(ov_model, file, compress_to_fp16=self.args.half)\n YAML.save(Path(file).parent / \"metadata.yaml\", self.metadata) # add metadata.yaml\n\n if self.args.int8:\n fq = str(self.file).replace(self.file.suffix, f\"_int8_openvino_model{os.sep}\")\n fq_ov = str(Path(fq) / self.file.with_suffix(\".xml\").name)\n # INT8 requires nncf, nncf requires packaging>=23.2 https://github.com/openvinotoolkit/nncf/issues/3463\n check_requirements(\"packaging>=23.2\") # must be installed first to build nncf wheel\n check_requirements(\"nncf>=2.14.0\")\n import nncf\n\n def transform_fn(data_item) -> np.ndarray:\n \"\"\"Quantization transform function.\"\"\"\n data_item: torch.Tensor = data_item[\"img\"] if isinstance(data_item, dict) else data_item\n assert data_item.dtype == torch.uint8, \"Input image must be uint8 for the quantization preprocessing\"\n im = data_item.numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0-255 to 0.0-1.0\n return np.expand_dims(im, 0) if im.ndim == 3 else im\n\n # Generate calibration data for integer quantization\n ignored_scope = None\n if isinstance(self.model.model[-1], Detect):\n # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect\n head_module_name = \".\".join(list(self.model.named_modules())[-1][0].split(\".\")[:2])\n ignored_scope = nncf.IgnoredScope( # ignore operations\n patterns=[\n f\".*{head_module_name}/.*/Add\",\n f\".*{head_module_name}/.*/Sub*\",\n f\".*{head_module_name}/.*/Mul*\",\n f\".*{head_module_name}/.*/Div*\",\n f\".*{head_module_name}\\\\.dfl.*\",\n ],\n types=[\"Sigmoid\"],\n )\n\n quantized_ov_model = nncf.quantize(\n model=ov_model,\n calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),\n preset=nncf.QuantizationPreset.MIXED,\n ignored_scope=ignored_scope,\n )\n serialize(quantized_ov_model, fq_ov)\n return fq, None\n\n f = str(self.file).replace(self.file.suffix, f\"_openvino_model{os.sep}\")\n f_ov = str(Path(f) / self.file.with_suffix(\".xml\").name)\n\n serialize(ov_model, f_ov)\n return f, None\n\n @try_export\n def export_paddle(self, prefix=colorstr(\"PaddlePaddle:\")):\n \"\"\"Export YOLO model to PaddlePaddle format.\"\"\"\n assert not IS_JETSON, \"Jetson Paddle exports not supported yet\"\n check_requirements(\n (\n \"paddlepaddle-gpu\"\n if torch.cuda.is_available()\n else \"paddlepaddle==3.0.0\" # pin 3.0.0 for ARM64\n if ARM64\n else \"paddlepaddle>=3.0.0\",\n \"x2paddle\",\n )\n )\n import x2paddle # noqa\n from x2paddle.convert import pytorch2paddle # noqa\n\n LOGGER.info(f\"\\n{prefix} starting export with X2Paddle {x2paddle.__version__}...\")\n f = str(self.file).replace(self.file.suffix, f\"_paddle_model{os.sep}\")\n\n pytorch2paddle(module=self.model, save_dir=f, jit_type=\"trace\", input_examples=[self.im]) # export\n YAML.save(Path(f) / \"metadata.yaml\", self.metadata) # add metadata.yaml\n return f, None\n\n 
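A hedged sketch of consuming the _openvino_model directory written by serialize() above with the OpenVINO runtime; file names are illustrative:

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model("yolo11n_openvino_model/yolo11n.xml", "AUTO")
dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)
result = compiled([dummy])[compiled.output(0)]  # synchronous inference, first output
print(result.shape)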
@try_export\n def export_mnn(self, prefix=colorstr(\"MNN:\")):\n \"\"\"Export YOLO model to MNN format using MNN https://github.com/alibaba/MNN.\"\"\"\n f_onnx, _ = self.export_onnx() # get onnx model first\n\n check_requirements(\"MNN>=2.9.6\")\n import MNN # noqa\n from MNN.tools import mnnconvert\n\n # Setup and checks\n LOGGER.info(f\"\\n{prefix} starting export with MNN {MNN.version()}...\")\n assert Path(f_onnx).exists(), f\"failed to export ONNX file: {f_onnx}\"\n f = str(self.file.with_suffix(\".mnn\")) # MNN model file\n args = [\"\", \"-f\", \"ONNX\", \"--modelFile\", f_onnx, \"--MNNModel\", f, \"--bizCode\", json.dumps(self.metadata)]\n if self.args.int8:\n args.extend((\"--weightQuantBits\", \"8\"))\n if self.args.half:\n args.append(\"--fp16\")\n mnnconvert.convert(args)\n # remove scratch file for model convert optimize\n convert_scratch = Path(self.file.parent / \".__convert_external_data.bin\")\n if convert_scratch.exists():\n convert_scratch.unlink()\n return f, None\n\n @try_export\n def export_ncnn(self, prefix=colorstr(\"NCNN:\")):\n \"\"\"Export YOLO model to NCNN format using PNNX https://github.com/pnnx/pnnx.\"\"\"\n check_requirements(\"ncnn\")\n import ncnn # noqa\n\n LOGGER.info(f\"\\n{prefix} starting export with NCNN {ncnn.__version__}...\")\n f = Path(str(self.file).replace(self.file.suffix, f\"_ncnn_model{os.sep}\"))\n f_ts = self.file.with_suffix(\".torchscript\")\n\n name = Path(\"pnnx.exe\" if WINDOWS else \"pnnx\") # PNNX filename\n pnnx = name if name.is_file() else (ROOT / name)\n if not pnnx.is_file():\n LOGGER.warning(\n f\"{prefix} PNNX not found. Attempting to download binary file from \"\n \"https://github.com/pnnx/pnnx/.\\nNote PNNX Binary file must be placed in current working directory \"\n f\"or in {ROOT}. See PNNX repo for full installation instructions.\"\n )\n system = \"macos\" if MACOS else \"windows\" if WINDOWS else \"linux-aarch64\" if ARM64 else \"linux\"\n try:\n release, assets = get_github_assets(repo=\"pnnx/pnnx\")\n asset = [x for x in assets if f\"{system}.zip\" in x][0]\n assert isinstance(asset, str), \"Unable to retrieve PNNX repo assets\" # i.e. 
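The mnnconvert.convert(args) call above mirrors MNN's command-line converter. An equivalent shell invocation via subprocess, assuming the MNN pip package (which provides the mnnconvert entry point) is installed; file names and the empty bizCode are illustrative:

import subprocess

cmd = ["mnnconvert", "-f", "ONNX", "--modelFile", "yolo11n.onnx", "--MNNModel", "yolo11n.mnn", "--bizCode", "{}"]
# extend with ("--weightQuantBits", "8") for int8 or "--fp16" for half, matching the flags used above
subprocess.run(cmd, check=True)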
pnnx-20240410-macos.zip\n LOGGER.info(f\"{prefix} successfully found latest PNNX asset file {asset}\")\n except Exception as e:\n release = \"20240410\"\n asset = f\"pnnx-{release}-{system}.zip\"\n LOGGER.warning(f\"{prefix} PNNX GitHub assets not found: {e}, using default {asset}\")\n unzip_dir = safe_download(f\"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}\", delete=True)\n if check_is_path_safe(Path.cwd(), unzip_dir): # avoid path traversal security vulnerability\n shutil.move(src=unzip_dir / name, dst=pnnx) # move binary to ROOT\n pnnx.chmod(0o777) # set read, write, and execute permissions for everyone\n shutil.rmtree(unzip_dir) # delete unzip dir\n\n ncnn_args = [\n f\"ncnnparam={f / 'model.ncnn.param'}\",\n f\"ncnnbin={f / 'model.ncnn.bin'}\",\n f\"ncnnpy={f / 'model_ncnn.py'}\",\n ]\n\n pnnx_args = [\n f\"pnnxparam={f / 'model.pnnx.param'}\",\n f\"pnnxbin={f / 'model.pnnx.bin'}\",\n f\"pnnxpy={f / 'model_pnnx.py'}\",\n f\"pnnxonnx={f / 'model.pnnx.onnx'}\",\n ]\n\n cmd = [\n str(pnnx),\n str(f_ts),\n *ncnn_args,\n *pnnx_args,\n f\"fp16={int(self.args.half)}\",\n f\"device={self.device.type}\",\n f'inputshape=\"{[self.args.batch, 3, *self.imgsz]}\"',\n ]\n f.mkdir(exist_ok=True) # make ncnn_model directory\n LOGGER.info(f\"{prefix} running '{' '.join(cmd)}'\")\n subprocess.run(cmd, check=True)\n\n # Remove debug files\n pnnx_files = [x.rsplit(\"=\", 1)[-1] for x in pnnx_args]\n for f_debug in (\"debug.bin\", \"debug.param\", \"debug2.bin\", \"debug2.param\", *pnnx_files):\n Path(f_debug).unlink(missing_ok=True)\n\n YAML.save(f / \"metadata.yaml\", self.metadata) # add metadata.yaml\n return str(f), None\n\n @try_export\n def export_coreml(self, prefix=colorstr(\"CoreML:\")):\n \"\"\"Export YOLO model to CoreML format.\"\"\"\n mlmodel = self.args.format.lower() == \"mlmodel\" # legacy *.mlmodel export format requested\n check_requirements(\"coremltools>=8.0\")\n import coremltools as ct # noqa\n\n LOGGER.info(f\"\\n{prefix} starting export with coremltools {ct.__version__}...\")\n assert not WINDOWS, \"CoreML export is not supported on Windows, please run on macOS or Linux.\"\n assert self.args.batch == 1, \"CoreML batch sizes > 1 are not supported. 
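A hedged sketch of loading the _ncnn_model directory produced above with the ncnn Python bindings; blob I/O is omitted and paths are illustrative (the model.ncnn.param/bin names come from the ncnn_args list above):

import ncnn

net = ncnn.Net()
net.load_param("yolo11n_ncnn_model/model.ncnn.param")
net.load_model("yolo11n_ncnn_model/model.ncnn.bin")
ex = net.create_extractor()  # feed inputs with ex.input(...) and read outputs with ex.extract(...)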
Please retry at 'batch=1'.\"\n f = self.file.with_suffix(\".mlmodel\" if mlmodel else \".mlpackage\")\n if f.is_dir():\n shutil.rmtree(f)\n\n bias = [0.0, 0.0, 0.0]\n scale = 1 / 255\n classifier_config = None\n if self.model.task == \"classify\":\n classifier_config = ct.ClassifierConfig(list(self.model.names.values()))\n model = self.model\n elif self.model.task == \"detect\":\n model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model\n else:\n if self.args.nms:\n LOGGER.warning(f\"{prefix} 'nms=True' is only available for Detect models like 'yolo11n.pt'.\")\n # TODO CoreML Segment and Pose model pipelining\n model = self.model\n ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model\n\n # Based on apple's documentation it is better to leave out the minimum_deployment target and let that get set\n # Internally based on the model conversion and output type.\n # Setting minimum_depoloyment_target >= iOS16 will require setting compute_precision=ct.precision.FLOAT32.\n # iOS16 adds in better support for FP16, but none of the CoreML NMS specifications handle FP16 as input.\n ct_model = ct.convert(\n ts,\n inputs=[ct.ImageType(\"image\", shape=self.im.shape, scale=scale, bias=bias)], # expects ct.TensorType\n classifier_config=classifier_config,\n convert_to=\"neuralnetwork\" if mlmodel else \"mlprogram\",\n )\n bits, mode = (8, \"kmeans\") if self.args.int8 else (16, \"linear\") if self.args.half else (32, None)\n if bits < 32:\n if \"kmeans\" in mode:\n check_requirements(\"scikit-learn\") # scikit-learn package required for k-means quantization\n if mlmodel:\n ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)\n elif bits == 8: # mlprogram already quantized to FP16\n import coremltools.optimize.coreml as cto\n\n op_config = cto.OpPalettizerConfig(mode=\"kmeans\", nbits=bits, weight_threshold=512)\n config = cto.OptimizationConfig(global_config=op_config)\n ct_model = cto.palettize_weights(ct_model, config=config)\n if self.args.nms and self.model.task == \"detect\":\n if mlmodel:\n weights_dir = None\n else:\n ct_model.save(str(f)) # save otherwise weights_dir does not exist\n weights_dir = str(f / \"Data/com.apple.CoreML/weights\")\n ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)\n\n m = self.metadata # metadata dict\n ct_model.short_description = m.pop(\"description\")\n ct_model.author = m.pop(\"author\")\n ct_model.license = m.pop(\"license\")\n ct_model.version = m.pop(\"version\")\n ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})\n if self.model.task == \"classify\":\n ct_model.user_defined_metadata.update({\"com.apple.coreml.model.preview.type\": \"imageClassifier\"})\n\n try:\n ct_model.save(str(f)) # save *.mlpackage\n except Exception as e:\n LOGGER.warning(\n f\"{prefix} CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. \"\n f\"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.\"\n )\n f = f.with_suffix(\".mlmodel\")\n ct_model.save(str(f))\n return f, ct_model\n\n @try_export\n def export_engine(self, dla=None, prefix=colorstr(\"TensorRT:\")):\n \"\"\"Export YOLO model to TensorRT format https://developer.nvidia.com/tensorrt.\"\"\"\n assert self.im.device.type != \"cpu\", \"export running on CPU but must be on GPU, i.e. 
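The *.mlpackage written above can be reloaded with coremltools to inspect the embedded metadata (prediction requires macOS; loading generally does not); the path is illustrative:

import coremltools as ct

ct_model = ct.models.MLModel("yolo11n.mlpackage")
print(ct_model.short_description)
print(ct_model.user_defined_metadata.get("stride"), ct_model.user_defined_metadata.get("names"))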
use 'device=0'\"\n f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016\n\n try:\n import tensorrt as trt # noqa\n except ImportError:\n if LINUX:\n check_requirements(\"tensorrt>7.0.0,!=10.1.0\")\n import tensorrt as trt # noqa\n check_version(trt.__version__, \">=7.0.0\", hard=True)\n check_version(trt.__version__, \"!=10.1.0\", msg=\"https://github.com/ultralytics/ultralytics/pull/14239\")\n\n # Setup and checks\n LOGGER.info(f\"\\n{prefix} starting export with TensorRT {trt.__version__}...\")\n assert Path(f_onnx).exists(), f\"failed to export ONNX file: {f_onnx}\"\n f = self.file.with_suffix(\".engine\") # TensorRT engine file\n export_engine(\n f_onnx,\n f,\n self.args.workspace,\n self.args.half,\n self.args.int8,\n self.args.dynamic,\n self.im.shape,\n dla=dla,\n dataset=self.get_int8_calibration_dataloader(prefix) if self.args.int8 else None,\n metadata=self.metadata,\n verbose=self.args.verbose,\n prefix=prefix,\n )\n\n return f, None\n\n @try_export\n def export_saved_model(self, prefix=colorstr(\"TensorFlow SavedModel:\")):\n \"\"\"Export YOLO model to TensorFlow SavedModel format.\"\"\"\n cuda = torch.cuda.is_available()\n try:\n import tensorflow as tf # noqa\n except ImportError:\n check_requirements(\"tensorflow>=2.0.0\")\n import tensorflow as tf # noqa\n check_requirements(\n (\n \"tf_keras\", # required by 'onnx2tf' package\n \"sng4onnx>=1.0.1\", # required by 'onnx2tf' package\n \"onnx_graphsurgeon>=0.3.26\", # required by 'onnx2tf' package\n \"ai-edge-litert>=1.2.0,<1.4.0\", # required by 'onnx2tf' package\n \"onnx>=1.12.0,<1.18.0\",\n \"onnx2tf>=1.26.3\",\n \"onnxslim>=0.1.59\",\n \"onnxruntime-gpu\" if cuda else \"onnxruntime\",\n \"protobuf>=5\",\n ),\n cmds=\"--extra-index-url https://pypi.ngc.nvidia.com\", # onnx_graphsurgeon only on NVIDIA\n )\n\n LOGGER.info(f\"\\n{prefix} starting export with tensorflow {tf.__version__}...\")\n check_version(\n tf.__version__,\n \">=2.0.0\",\n name=\"tensorflow\",\n verbose=True,\n msg=\"https://github.com/ultralytics/ultralytics/issues/5161\",\n )\n f = Path(str(self.file).replace(self.file.suffix, \"_saved_model\"))\n if f.is_dir():\n shutil.rmtree(f) # delete output folder\n\n # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545\n onnx2tf_file = Path(\"calibration_image_sample_data_20x128x128x3_float32.npy\")\n if not onnx2tf_file.exists():\n attempt_download_asset(f\"{onnx2tf_file}.zip\", unzip=True, delete=True)\n\n # Export to ONNX\n self.args.simplify = True\n f_onnx, _ = self.export_onnx()\n\n # Export to TF\n np_data = None\n if self.args.int8:\n tmp_file = f / \"tmp_tflite_int8_calibration_images.npy\" # int8 calibration images file\n if self.args.data:\n f.mkdir()\n images = [batch[\"img\"] for batch in self.get_int8_calibration_dataloader(prefix)]\n images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(\n 0, 2, 3, 1\n )\n np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC\n np_data = [[\"images\", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]\n\n import onnx2tf # scoped for after ONNX export for reduced conflict during import\n\n LOGGER.info(f\"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...\")\n keras_model = onnx2tf.convert(\n input_onnx_file_path=f_onnx,\n output_folder_path=str(f),\n not_use_onnxsim=True,\n verbosity=\"error\", # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873\n 
output_integer_quantized_tflite=self.args.int8,\n quant_type=\"per-tensor\", # \"per-tensor\" (faster) or \"per-channel\" (slower but more accurate)\n custom_input_op_name_np_data_path=np_data,\n enable_batchmatmul_unfold=True, # fix lower no. of detected objects on GPU delegate\n output_signaturedefs=True, # fix error with Attention block group convolution\n disable_group_convolution=self.args.format in {\"tfjs\", \"edgetpu\"}, # fix error with group convolution\n optimization_for_gpu_delegate=True,\n )\n YAML.save(f / \"metadata.yaml\", self.metadata) # add metadata.yaml\n\n # Remove/rename TFLite models\n if self.args.int8:\n tmp_file.unlink(missing_ok=True)\n for file in f.rglob(\"*_dynamic_range_quant.tflite\"):\n file.rename(file.with_name(file.stem.replace(\"_dynamic_range_quant\", \"_int8\") + file.suffix))\n for file in f.rglob(\"*_integer_quant_with_int16_act.tflite\"):\n file.unlink() # delete extra fp16 activation TFLite files\n\n # Add TFLite metadata\n for file in f.rglob(\"*.tflite\"):\n f.unlink() if \"quant_with_int16_act.tflite\" in str(f) else self._add_tflite_metadata(file)\n\n return str(f), keras_model # or keras_model = tf.saved_model.load(f, tags=None, options=None)\n\n @try_export\n def export_pb(self, keras_model, prefix=colorstr(\"TensorFlow GraphDef:\")):\n \"\"\"Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow.\"\"\"\n import tensorflow as tf # noqa\n from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa\n\n LOGGER.info(f\"\\n{prefix} starting export with tensorflow {tf.__version__}...\")\n f = self.file.with_suffix(\".pb\")\n\n m = tf.function(lambda x: keras_model(x)) # full model\n m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))\n frozen_func = convert_variables_to_constants_v2(m)\n frozen_func.graph.as_graph_def()\n tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)\n return f, None\n\n @try_export\n def export_tflite(self, prefix=colorstr(\"TensorFlow Lite:\")):\n \"\"\"Export YOLO model to TensorFlow Lite format.\"\"\"\n # BUG https://github.com/ultralytics/ultralytics/issues/13436\n import tensorflow as tf # noqa\n\n LOGGER.info(f\"\\n{prefix} starting export with tensorflow {tf.__version__}...\")\n saved_model = Path(str(self.file).replace(self.file.suffix, \"_saved_model\"))\n if self.args.int8:\n f = saved_model / f\"{self.file.stem}_int8.tflite\" # fp32 in/out\n elif self.args.half:\n f = saved_model / f\"{self.file.stem}_float16.tflite\" # fp32 in/out\n else:\n f = saved_model / f\"{self.file.stem}_float32.tflite\"\n return str(f), None\n\n @try_export\n def export_edgetpu(self, tflite_model=\"\", prefix=colorstr(\"Edge TPU:\")):\n \"\"\"Export YOLO model to Edge TPU format https://coral.ai/docs/edgetpu/models-intro/.\"\"\"\n cmd = \"edgetpu_compiler --version\"\n help_url = \"https://coral.ai/docs/edgetpu/compiler/\"\n assert LINUX, f\"export only supported on Linux. See {help_url}\"\n if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:\n LOGGER.info(f\"\\n{prefix} export requires Edge TPU compiler. 
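A sketch of running one of the *.tflite files produced above with the TensorFlow Lite interpreter; the _float32 file name follows the export_tflite naming and the path is illustrative:

import numpy as np
import tensorflow as tf

interp = tf.lite.Interpreter(model_path="yolo11n_saved_model/yolo11n_float32.tflite")
interp.allocate_tensors()
inp = interp.get_input_details()[0]
interp.set_tensor(inp["index"], np.zeros(inp["shape"], dtype=inp["dtype"]))
interp.invoke()
print(interp.get_tensor(interp.get_output_details()[0]["index"]).shape)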
Attempting install from {help_url}\")\n for c in (\n \"curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -\",\n 'echo \"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\" | '\n \"sudo tee /etc/apt/sources.list.d/coral-edgetpu.list\",\n \"sudo apt-get update\",\n \"sudo apt-get install edgetpu-compiler\",\n ):\n subprocess.run(c if is_sudo_available() else c.replace(\"sudo \", \"\"), shell=True, check=True)\n ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]\n\n LOGGER.info(f\"\\n{prefix} starting export with Edge TPU compiler {ver}...\")\n f = str(tflite_model).replace(\".tflite\", \"_edgetpu.tflite\") # Edge TPU model\n\n cmd = (\n \"edgetpu_compiler \"\n f'--out_dir \"{Path(f).parent}\" '\n \"--show_operations \"\n \"--search_delegate \"\n \"--delegate_search_step 30 \"\n \"--timeout_sec 180 \"\n f'\"{tflite_model}\"'\n )\n LOGGER.info(f\"{prefix} running '{cmd}'\")\n subprocess.run(cmd, shell=True)\n self._add_tflite_metadata(f)\n return f, None\n\n @try_export\n def export_tfjs(self, prefix=colorstr(\"TensorFlow.js:\")):\n \"\"\"Export YOLO model to TensorFlow.js format.\"\"\"\n check_requirements(\"tensorflowjs\")\n import tensorflow as tf\n import tensorflowjs as tfjs # noqa\n\n LOGGER.info(f\"\\n{prefix} starting export with tensorflowjs {tfjs.__version__}...\")\n f = str(self.file).replace(self.file.suffix, \"_web_model\") # js dir\n f_pb = str(self.file.with_suffix(\".pb\")) # *.pb path\n\n gd = tf.Graph().as_graph_def() # TF GraphDef\n with open(f_pb, \"rb\") as file:\n gd.ParseFromString(file.read())\n outputs = \",\".join(gd_outputs(gd))\n LOGGER.info(f\"\\n{prefix} output node names: {outputs}\")\n\n quantization = \"--quantize_float16\" if self.args.half else \"--quantize_uint8\" if self.args.int8 else \"\"\n with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path\n cmd = (\n \"tensorflowjs_converter \"\n f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} \"{fpb_}\" \"{f_}\"'\n )\n LOGGER.info(f\"{prefix} running '{cmd}'\")\n subprocess.run(cmd, shell=True)\n\n if \" \" in f:\n LOGGER.warning(f\"{prefix} your model may not work correctly with spaces in path '{f}'.\")\n\n # Add metadata\n YAML.save(Path(f) / \"metadata.yaml\", self.metadata) # add metadata.yaml\n return f, None\n\n @try_export\n def export_rknn(self, prefix=colorstr(\"RKNN:\")):\n \"\"\"Export YOLO model to RKNN format.\"\"\"\n LOGGER.info(f\"\\n{prefix} starting export with rknn-toolkit2...\")\n\n check_requirements(\"rknn-toolkit2\")\n if IS_COLAB:\n # Prevent 'exit' from closing the notebook https://github.com/airockchip/rknn-toolkit2/issues/259\n import builtins\n\n builtins.exit = lambda: None\n\n from rknn.api import RKNN\n\n f, _ = self.export_onnx()\n export_path = Path(f\"{Path(f).stem}_rknn_model\")\n export_path.mkdir(exist_ok=True)\n\n rknn = RKNN(verbose=False)\n rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=self.args.name)\n rknn.load_onnx(model=f)\n rknn.build(do_quantization=False) # TODO: Add quantization support\n f = f.replace(\".onnx\", f\"-{self.args.name}.rknn\")\n rknn.export_rknn(f\"{export_path / f}\")\n YAML.save(export_path / \"metadata.yaml\", self.metadata)\n return export_path, None\n\n @try_export\n def export_imx(self, prefix=colorstr(\"IMX:\")):\n \"\"\"Export YOLO model to IMX format.\"\"\"\n gptq = False\n assert LINUX, (\n \"export only supported on Linux. 
\"\n \"See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter\"\n )\n if getattr(self.model, \"end2end\", False):\n raise ValueError(\"IMX export is not supported for end2end models.\")\n check_requirements((\"model-compression-toolkit>=2.4.1\", \"sony-custom-layers>=0.3.0\", \"edge-mdt-tpc>=1.1.0\"))\n check_requirements(\"imx500-converter[pt]>=3.16.1\") # Separate requirements for imx500-converter\n check_requirements(\"mct-quantizers>=1.6.0\") # Separate for compatibility with model-compression-toolkit\n\n import model_compression_toolkit as mct\n import onnx\n from edgemdt_tpc import get_target_platform_capabilities\n from sony_custom_layers.pytorch import multiclass_nms_with_indices\n\n LOGGER.info(f\"\\n{prefix} starting export with model_compression_toolkit {mct.__version__}...\")\n\n # Install Java>=17\n try:\n java_output = subprocess.run([\"java\", \"--version\"], check=True, capture_output=True).stdout.decode()\n version_match = re.search(r\"(?:openjdk|java) (\\d+)\", java_output)\n java_version = int(version_match.group(1)) if version_match else 0\n assert java_version >= 17, \"Java version too old\"\n except (FileNotFoundError, subprocess.CalledProcessError, AssertionError):\n cmd = ([\"sudo\"] if is_sudo_available() else []) + [\"apt\", \"install\", \"-y\", \"openjdk-21-jre\"]\n subprocess.run(cmd, check=True)\n\n def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):\n for batch in dataloader:\n img = batch[\"img\"]\n img = img / 255.0\n yield [img]\n\n tpc = get_target_platform_capabilities(tpc_version=\"4.0\", device_type=\"imx500\")\n\n bit_cfg = mct.core.BitWidthConfig()\n if \"C2PSA\" in self.model.__str__(): # YOLO11\n if self.model.task == \"detect\":\n layer_names = [\"sub\", \"mul_2\", \"add_14\", \"cat_21\"]\n weights_memory = 2585350.2439\n n_layers = 238 # 238 layers for fused YOLO11n\n elif self.model.task == \"pose\":\n layer_names = [\"sub\", \"mul_2\", \"add_14\", \"cat_22\", \"cat_23\", \"mul_4\", \"add_15\"]\n weights_memory = 2437771.67\n n_layers = 257 # 257 layers for fused YOLO11n-pose\n else: # YOLOv8\n if self.model.task == \"detect\":\n layer_names = [\"sub\", \"mul\", \"add_6\", \"cat_17\"]\n weights_memory = 2550540.8\n n_layers = 168 # 168 layers for fused YOLOv8n\n elif self.model.task == \"pose\":\n layer_names = [\"add_7\", \"mul_2\", \"cat_19\", \"mul\", \"sub\", \"add_6\", \"cat_18\"]\n weights_memory = 2482451.85\n n_layers = 187 # 187 layers for fused YOLO11n-pose\n\n # Check if the model has the expected number of layers\n if len(list(self.model.modules())) != n_layers:\n raise ValueError(\"IMX export only supported for YOLOv8n and YOLO11n models.\")\n\n for layer_name in layer_names:\n bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)\n\n config = mct.core.CoreConfig(\n mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),\n quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),\n bit_width_config=bit_cfg,\n )\n\n resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)\n\n quant_model = (\n mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization\n model=self.model,\n representative_data_gen=representative_dataset_gen,\n target_resource_utilization=resource_utilization,\n gptq_config=mct.gptq.get_pytorch_gptq_config(\n n_epochs=1000, use_hessian_based_weights=False, 
use_hessian_sample_attention=False\n ),\n core_config=config,\n target_platform_capabilities=tpc,\n )[0]\n if gptq\n else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization\n in_module=self.model,\n representative_data_gen=representative_dataset_gen,\n target_resource_utilization=resource_utilization,\n core_config=config,\n target_platform_capabilities=tpc,\n )[0]\n )\n\n class NMSWrapper(torch.nn.Module):\n \"\"\"Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers.\"\"\"\n\n def __init__(\n self,\n model: torch.nn.Module,\n score_threshold: float = 0.001,\n iou_threshold: float = 0.7,\n max_detections: int = 300,\n task: str = \"detect\",\n ):\n \"\"\"\n Initialize NMSWrapper with PyTorch Module and NMS parameters.\n\n Args:\n model (torch.nn.Module): Model instance.\n score_threshold (float): Score threshold for non-maximum suppression.\n iou_threshold (float): Intersection over union threshold for non-maximum suppression.\n max_detections (int): The number of detections to return.\n task (str): Task type, either 'detect' or 'pose'.\n \"\"\"\n super().__init__()\n self.model = model\n self.score_threshold = score_threshold\n self.iou_threshold = iou_threshold\n self.max_detections = max_detections\n self.task = task\n\n def forward(self, images):\n \"\"\"Forward pass with model inference and NMS post-processing.\"\"\"\n # model inference\n outputs = self.model(images)\n\n boxes, scores = outputs[0], outputs[1]\n nms_outputs = multiclass_nms_with_indices(\n boxes=boxes,\n scores=scores,\n score_threshold=self.score_threshold,\n iou_threshold=self.iou_threshold,\n max_detections=self.max_detections,\n )\n if self.task == \"pose\":\n kpts = outputs[2] # (bs, max_detections, kpts 17*3)\n out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))\n return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts\n return nms_outputs\n\n quant_model = NMSWrapper(\n model=quant_model,\n score_threshold=self.args.conf or 0.001,\n iou_threshold=self.args.iou,\n max_detections=self.args.max_det,\n task=self.model.task,\n ).to(self.device)\n\n f = Path(str(self.file).replace(self.file.suffix, \"_imx_model\"))\n f.mkdir(exist_ok=True)\n onnx_model = f / Path(str(self.file.name).replace(self.file.suffix, \"_imx.onnx\")) # js dir\n mct.exporter.pytorch_export_model(\n model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen\n )\n\n model_onnx = onnx.load(onnx_model) # load onnx model\n for k, v in self.metadata.items():\n meta = model_onnx.metadata_props.add()\n meta.key, meta.value = k, str(v)\n\n onnx.save(model_onnx, onnx_model)\n\n subprocess.run(\n [\"imxconv-pt\", \"-i\", str(onnx_model), \"-o\", str(f), \"--no-input-persistency\", \"--overwrite-output\"],\n check=True,\n )\n\n # Needed for imx models.\n with open(f / \"labels.txt\", \"w\", encoding=\"utf-8\") as file:\n file.writelines([f\"{name}\\n\" for _, name in self.model.names.items()])\n\n return f, None\n\n def _add_tflite_metadata(self, file):\n \"\"\"Add metadata to *.tflite models per https://ai.google.dev/edge/litert/models/metadata.\"\"\"\n import zipfile\n\n with zipfile.ZipFile(file, \"a\", zipfile.ZIP_DEFLATED) as zf:\n zf.writestr(\"metadata.json\", json.dumps(self.metadata, indent=2))\n\n def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr(\"CoreML Pipeline:\")):\n \"\"\"Create CoreML pipeline with NMS for YOLO detection models.\"\"\"\n import coremltools as ct # noqa\n\n 
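# Illustrative sketch (not part of the exported source) of the index-gather used by the pose
# branch of NMSWrapper.forward above; shapes and values below are hypothetical stand-ins.
import torch

bs, num_anchors, kpt_dim, max_det = 2, 8400, 17 * 3, 300
kpts = torch.randn(bs, num_anchors, kpt_dim)                 # raw per-anchor keypoints
keep = torch.randint(0, num_anchors, (bs, max_det))          # indices returned by the NMS layer
out_kpts = torch.gather(kpts, 1, keep.unsqueeze(-1).expand(-1, -1, kpt_dim))
assert out_kpts.shape == (bs, max_det, kpt_dim)              # keypoints now aligned with kept boxes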
LOGGER.info(f\"{prefix} starting pipeline with coremltools {ct.__version__}...\")\n _, _, h, w = list(self.im.shape) # BCHW\n\n # Output shapes\n spec = model.get_spec()\n out0, out1 = iter(spec.description.output)\n if MACOS:\n from PIL import Image\n\n img = Image.new(\"RGB\", (w, h)) # w=192, h=320\n out = model.predict({\"image\": img})\n out0_shape = out[out0.name].shape # (3780, 80)\n out1_shape = out[out1.name].shape # (3780, 4)\n else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y\n out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)\n out1_shape = self.output_shape[2], 4 # (3780, 4)\n\n # Checks\n names = self.metadata[\"names\"]\n nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height\n _, nc = out0_shape # number of anchors, number of classes\n assert len(names) == nc, f\"{len(names)} names found for nc={nc}\" # check\n\n # Define output shapes (missing)\n out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)\n out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)\n\n # Model from spec\n model = ct.models.MLModel(spec, weights_dir=weights_dir)\n\n # 3. Create NMS protobuf\n nms_spec = ct.proto.Model_pb2.Model()\n nms_spec.specificationVersion = spec.specificationVersion\n for i in range(2):\n decoder_output = model._spec.description.output[i].SerializeToString()\n nms_spec.description.input.add()\n nms_spec.description.input[i].ParseFromString(decoder_output)\n nms_spec.description.output.add()\n nms_spec.description.output[i].ParseFromString(decoder_output)\n\n nms_spec.description.output[0].name = \"confidence\"\n nms_spec.description.output[1].name = \"coordinates\"\n\n output_sizes = [nc, 4]\n for i in range(2):\n ma_type = nms_spec.description.output[i].type.multiArrayType\n ma_type.shapeRange.sizeRanges.add()\n ma_type.shapeRange.sizeRanges[0].lowerBound = 0\n ma_type.shapeRange.sizeRanges[0].upperBound = -1\n ma_type.shapeRange.sizeRanges.add()\n ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]\n ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]\n del ma_type.shape[:]\n\n nms = nms_spec.nonMaximumSuppression\n nms.confidenceInputFeatureName = out0.name # 1x507x80\n nms.coordinatesInputFeatureName = out1.name # 1x507x4\n nms.confidenceOutputFeatureName = \"confidence\"\n nms.coordinatesOutputFeatureName = \"coordinates\"\n nms.iouThresholdInputFeatureName = \"iouThreshold\"\n nms.confidenceThresholdInputFeatureName = \"confidenceThreshold\"\n nms.iouThreshold = self.args.iou\n nms.confidenceThreshold = self.args.conf\n nms.pickTop.perClass = True\n nms.stringClassLabels.vector.extend(names.values())\n nms_model = ct.models.MLModel(nms_spec)\n\n # 4. 
Pipeline models together\n pipeline = ct.models.pipeline.Pipeline(\n input_features=[\n (\"image\", ct.models.datatypes.Array(3, ny, nx)),\n (\"iouThreshold\", ct.models.datatypes.Double()),\n (\"confidenceThreshold\", ct.models.datatypes.Double()),\n ],\n output_features=[\"confidence\", \"coordinates\"],\n )\n pipeline.add_model(model)\n pipeline.add_model(nms_model)\n\n # Correct datatypes\n pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())\n pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())\n pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())\n\n # Update metadata\n pipeline.spec.specificationVersion = spec.specificationVersion\n pipeline.spec.description.metadata.userDefined.update(\n {\"IoU threshold\": str(nms.iouThreshold), \"Confidence threshold\": str(nms.confidenceThreshold)}\n )\n\n # Save the model\n model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)\n model.input_description[\"image\"] = \"Input image\"\n model.input_description[\"iouThreshold\"] = f\"(optional) IoU threshold override (default: {nms.iouThreshold})\"\n model.input_description[\"confidenceThreshold\"] = (\n f\"(optional) Confidence threshold override (default: {nms.confidenceThreshold})\"\n )\n model.output_description[\"confidence\"] = 'Boxes × Class confidence (see user-defined metadata \"classes\")'\n model.output_description[\"coordinates\"] = \"Boxes × [x, y, width, height] (relative to image size)\"\n LOGGER.info(f\"{prefix} pipeline success\")\n return model\n\n def add_callback(self, event: str, callback):\n \"\"\"Append the given callback to the specified event.\"\"\"\n self.callbacks[event].append(callback)\n\n def run_callbacks(self, event: str):\n \"\"\"Execute all callbacks for a given event.\"\"\"\n for callback in self.callbacks.get(event, []):\n callback(self)", "chunk_type": "class", "name": "Exporter", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 203, "end_line": 1457, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "A class for exporting YOLO models to various formats.\n\nThis class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,\nTensorFlow, and others. 
It handles format validation, device selection, model preparation, and the actual export\nprocess for each supported format.\n\nAttributes:\n args (SimpleNamespace): Configuration arguments for the exporter.\n callbacks (dict): Dictionary of callback functions for different export events.\n im (torch.Tensor): Input tensor for model inference during export.\n model (torch.nn.Module): The YOLO model to be exported.\n file (Path): Path to the model file being exported.\n output_shape (tuple): Shape of the model output tensor(s).\n pretty_name (str): Formatted model name for display purposes.\n metadata (dict): Model metadata including description, author, version, etc.\n device (torch.device): Device on which the model is loaded.\n imgsz (tuple): Input image size for the model.\n\nMethods:\n __call__: Main export method that handles the export process.\n get_int8_calibration_dataloader: Build dataloader for INT8 calibration.\n export_torchscript: Export model to TorchScript format.\n export_onnx: Export model to ONNX format.\n export_openvino: Export model to OpenVINO format.\n export_paddle: Export model to PaddlePaddle format.\n export_mnn: Export model to MNN format.\n export_ncnn: Export model to NCNN format.\n export_coreml: Export model to CoreML format.\n export_engine: Export model to TensorRT format.\n export_saved_model: Export model to TensorFlow SavedModel format.\n export_pb: Export model to TensorFlow GraphDef format.\n export_tflite: Export model to TensorFlow Lite format.\n export_edgetpu: Export model to Edge TPU format.\n export_tfjs: Export model to TensorFlow.js format.\n export_rknn: Export model to RKNN format.\n export_imx: Export model to IMX format.\n\nExamples:\n Export a YOLOv8 model to ONNX format\n >>> from ultralytics.engine.exporter import Exporter\n >>> exporter = Exporter()\n >>> exporter(model=\"yolov8n.pt\") # exports to yolov8n.onnx\n\n Export with specific arguments\n >>> args = {\"format\": \"onnx\", \"dynamic\": True, \"half\": True}\n >>> exporter = Exporter(overrides=args)\n >>> exporter(model=\"yolov8n.pt\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", 
"ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", "tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml" ], "chunk_id": "class_Exporter_764adf24" }, { "content": "class IOSDetectModel(torch.nn.Module):\n \"\"\"Wrap an Ultralytics YOLO model for Apple iOS CoreML export.\"\"\"\n\n def __init__(self, model, im):\n \"\"\"\n Initialize the IOSDetectModel class with a YOLO model and example image.\n\n Args:\n model (torch.nn.Module): The YOLO model to wrap.\n im (torch.Tensor): Example input tensor with shape (B, C, H, W).\n \"\"\"\n super().__init__()\n _, _, h, w = im.shape # batch, channel, height, width\n self.model = model\n self.nc = len(model.names) # number of classes\n if w == h:\n self.normalize = 1.0 / w # scalar\n else:\n self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)\n\n def forward(self, x):\n \"\"\"Normalize predictions of object detection model with input size-dependent factors.\"\"\"\n xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)\n return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)", "chunk_type": "class", "name": "IOSDetectModel", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 1460, "end_line": 1483, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Wrap an Ultralytics YOLO model for Apple iOS CoreML export.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", 
"ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", "tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml", "torch.nn.Module" ], "chunk_id": "class_IOSDetectModel_4201a8a7" }, { "content": "class NMSModel(torch.nn.Module):\n \"\"\"Model wrapper with embedded NMS for Detect, Segment, Pose and OBB.\"\"\"\n\n def __init__(self, model, args):\n \"\"\"\n Initialize the NMSModel.\n\n Args:\n model (torch.nn.Module): The model to wrap with NMS postprocessing.\n args (Namespace): The export arguments.\n \"\"\"\n super().__init__()\n self.model = model\n self.args = args\n self.obb = model.task == \"obb\"\n self.is_tf = self.args.format in frozenset({\"saved_model\", \"tflite\", \"tfjs\"})\n\n def forward(self, x):\n \"\"\"\n Perform inference with NMS post-processing. 
Supports Detect, Segment, OBB and Pose.\n\n Args:\n x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).\n\n Returns:\n (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the\n number of detections after NMS.\n \"\"\"\n from functools import partial\n\n from torchvision.ops import nms\n\n preds = self.model(x)\n pred = preds[0] if isinstance(preds, tuple) else preds\n kwargs = dict(device=pred.device, dtype=pred.dtype)\n bs = pred.shape[0]\n pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)\n extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose\n if self.args.dynamic and self.args.batch > 1: # batch size needs to always be same due to loop unroll\n pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs)\n pred = torch.cat((pred, pad))\n boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2)\n scores, classes = scores.max(dim=-1)\n self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det\n # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).\n out = torch.zeros(pred.shape[0], self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs)\n for i in range(bs):\n box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i]\n mask = score > self.args.conf\n if self.is_tf:\n # TFLite GatherND error if mask is empty\n score *= mask\n # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`\n mask = score.topk(min(self.args.max_det * 5, score.shape[0])).indices\n box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]\n nmsbox = box.clone()\n # `8` is the minimum value experimented to get correct NMS results for obb\n multiplier = 8 if self.obb else 1\n # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization\n if self.args.format == \"tflite\": # TFLite is already normalized\n nmsbox *= multiplier\n else:\n nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max()\n if not self.args.agnostic_nms: # class-specific NMS\n end = 2 if self.obb else 4\n # fully explicit expansion otherwise reshape error\n # large max_wh causes issues when quantizing\n cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)\n offbox = nmsbox[:, :end] + cls_offset * multiplier\n nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)\n nms_fn = (\n partial(\n nms_rotated,\n use_triu=not (\n self.is_tf\n or (self.args.opset or 14) < 14\n or (self.args.format == \"openvino\" and self.args.int8) # OpenVINO int8 error with triu\n ),\n )\n if self.obb\n else nms\n )\n keep = nms_fn(\n torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,\n score,\n self.args.iou,\n )[: self.args.max_det]\n dets = torch.cat(\n [box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1\n )\n # Zero-pad to max_det size to avoid reshape error\n pad = (0, 0, 0, self.args.max_det - dets.shape[0])\n out[i] = torch.nn.functional.pad(dets, pad)\n return (out[:bs], preds[1]) if self.model.task == \"segment\" else out[:bs]", "chunk_type": "class", "name": "NMSModel", "file_path": "ultralytics\\ultralytics\\engine\\exporter.py", "start_line": 1486, "end_line": 1579, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "Model wrapper with embedded NMS for Detect, Segment, Pose and OBB.", "parameters": null, "return_type": null, "decorators": [], 
"complexity_score": null, "dependencies": [ "json", "os", "re", "shutil", "subprocess", "time", "warnings", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "numpy", "torch", "ultralytics.__version__", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", "ultralytics.data.build_dataloader", "ultralytics.data.dataset.YOLODataset", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.autobackend.default_class_names", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.ARM64", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.MACOS_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.SETTINGS", "ultralytics.utils.WINDOWS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.get_default_args", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_is_path_safe", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_intel", "ultralytics.utils.checks.is_sudo_available", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.get_github_assets", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.export.export_engine", "ultralytics.utils.export.export_onnx", "ultralytics.utils.files.file_size", "ultralytics.utils.files.spaces_in_path", "ultralytics.utils.ops.Profile", "ultralytics.utils.ops.nms_rotated", "ultralytics.utils.patches.arange_patch", "ultralytics.utils.torch_utils.TORCH_1_13", "ultralytics.utils.torch_utils.get_latest_opset", "ultralytics.utils.torch_utils.select_device", "onnx", "openvino", "x2paddle", "x2paddle.convert.pytorch2paddle", "MNN", "MNN.tools.mnnconvert", "ncnn", "coremltools", "onnx2tf", "tensorflow", "tensorflow.python.framework.convert_to_constants.convert_variables_to_constants_v2", "tensorflow", "tensorflow", "tensorflowjs", "rknn.api.RKNN", "model_compression_toolkit", "onnx", "edgemdt_tpc.get_target_platform_capabilities", "sony_custom_layers.pytorch.multiclass_nms_with_indices", "zipfile", "coremltools", "functools.partial", "torchvision.ops.nms", "difflib", "ultralytics.utils.torch_utils.FXModel", "torch.utils.mobile_optimizer.optimize_for_mobile", "nncf", "tensorrt", "tensorflow", "builtins", "PIL.Image", "ultralytics.utils.tal.make_anchors", "onnxslim", "tensorrt", "tensorflow", "coremltools.optimize.coreml", "torch.nn.Module" ], "chunk_id": "class_NMSModel_20869e28" }, { "content": "import inspect", "chunk_type": "import", "name": "inspect", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_inspect_21956667" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, 
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_5d3be677" }, { "content": "from typing import Any, Dict, List, Union", "chunk_type": "import", "name": "Any, Dict, List, Union", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Union_473bbf8f" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_0fae645c" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_3d17b27a" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_8c992b41" }, { "content": "from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir", "chunk_type": "import", "name": "TASK2DATA, get_cfg, get_save_dir", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, get_cfg, get_save_dir_b4a0e55e" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_68eb8af7" }, { "content": "from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load", "chunk_type": "import", "name": "attempt_load_one_weight, guess_model_task, yaml_model_load", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_attempt_load_one_weight, guess_model_task, yaml_model_load_cf9e31e5" }, { "content": "from ultralytics.utils import (\n ARGV,\n ASSETS,\n DEFAULT_CFG_DICT,\n LOGGER,\n RANK,\n SETTINGS,\n YAML,\n callbacks,\n checks,\n)", "chunk_type": "import", "name": "ARGV, ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, YAML, callbacks, checks", "file_path": 
"ultralytics\\ultralytics\\engine\\model.py", "start_line": 14, "end_line": 24, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARGV, ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, YAML, callbacks, checks_d937d73d" }, { "content": "class Model(torch.nn.Module):\n \"\"\"\n A base class for implementing YOLO models, unifying APIs across different model types.\n\n This class provides a common interface for various operations related to YOLO models, such as training,\n validation, prediction, exporting, and benchmarking. It handles different types of models, including those\n loaded from local files, Ultralytics HUB, or Triton Server.\n\n Attributes:\n callbacks (dict): A dictionary of callback functions for various events during model operations.\n predictor (BasePredictor): The predictor object used for making predictions.\n model (torch.nn.Module): The underlying PyTorch model.\n trainer (BaseTrainer): The trainer object used for training the model.\n ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.\n cfg (str): The configuration of the model if loaded from a *.yaml file.\n ckpt_path (str): The path to the checkpoint file.\n overrides (dict): A dictionary of overrides for model configuration.\n metrics (dict): The latest training/validation metrics.\n session (HUBTrainingSession): The Ultralytics HUB session, if applicable.\n task (str): The type of task the model is intended for.\n model_name (str): The name of the model.\n\n Methods:\n __call__: Alias for the predict method, enabling the model instance to be callable.\n _new: Initialize a new model based on a configuration file.\n _load: Load a model from a checkpoint file.\n _check_is_pytorch_model: Ensure that the model is a PyTorch model.\n reset_weights: Reset the model's weights to their initial state.\n load: Load model weights from a specified file.\n save: Save the current state of the model to a file.\n info: Log or return information about the model.\n fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.\n predict: Perform object detection predictions.\n track: Perform object tracking.\n val: Validate the model on a dataset.\n benchmark: Benchmark the model on various export formats.\n export: Export the model to different formats.\n train: Train the model on a dataset.\n tune: Perform hyperparameter tuning.\n _apply: Apply a function to the model's tensors.\n add_callback: Add a callback function for an event.\n clear_callback: Clear all callbacks for an event.\n reset_callbacks: Reset all callbacks to their default functions.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.predict(\"image.jpg\")\n >>> model.train(data=\"coco8.yaml\", epochs=3)\n >>> metrics = model.val()\n >>> model.export(format=\"onnx\")\n \"\"\"\n\n def __init__(\n self,\n model: Union[str, Path, \"Model\"] = \"yolo11n.pt\",\n task: str = None,\n verbose: bool = False,\n ) -> None:\n \"\"\"\n Initialize a new instance of the YOLO model class.\n\n This constructor sets up the model based on the provided model path or name. It handles various types of\n model sources, including local files, Ultralytics HUB models, and Triton Server models. 
The method\n initializes several important attributes of the model and prepares it for operations like training,\n prediction, or export.\n\n Args:\n model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a\n model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.\n task (str, optional): The specific task for the model. If None, it will be inferred from the config.\n verbose (bool): If True, enables verbose output during the model's initialization and subsequent\n operations.\n\n Raises:\n FileNotFoundError: If the specified model file does not exist or is inaccessible.\n ValueError: If the model file or configuration is invalid or unsupported.\n ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model = Model(\"path/to/model.yaml\", task=\"detect\")\n >>> model = Model(\"hub_model\", verbose=True)\n \"\"\"\n if isinstance(model, Model):\n self.__dict__ = model.__dict__ # accepts an already initialized Model\n return\n super().__init__()\n self.callbacks = callbacks.get_default_callbacks()\n self.predictor = None # reuse predictor\n self.model = None # model object\n self.trainer = None # trainer object\n self.ckpt = {} # if loaded from *.pt\n self.cfg = None # if loaded from *.yaml\n self.ckpt_path = None\n self.overrides = {} # overrides for trainer object\n self.metrics = None # validation/training metrics\n self.session = None # HUB session\n self.task = task # task type\n self.model_name = None # model name\n model = str(model).strip()\n\n # Check if Ultralytics HUB model from https://hub.ultralytics.com\n if self.is_hub_model(model):\n from ultralytics.hub import HUBTrainingSession\n\n # Fetch model from HUB\n checks.check_requirements(\"hub-sdk>=0.0.12\")\n session = HUBTrainingSession.create_session(model)\n model = session.model_file\n if session.train_args: # training sent from HUB\n self.session = session\n\n # Check if Triton Server model\n elif self.is_triton_model(model):\n self.model_name = self.model = model\n self.overrides[\"task\"] = task or \"detect\" # set `task=detect` if not explicitly set\n return\n\n # Load or create new YOLO model\n __import__(\"os\").environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\" # to avoid deterministic warnings\n if str(model).endswith((\".yaml\", \".yml\")):\n self._new(model, task=task, verbose=verbose)\n else:\n self._load(model, task=task)\n\n # Delete super().training for accessing self.model.training\n del self.training\n\n def __call__(\n self,\n source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,\n stream: bool = False,\n **kwargs: Any,\n ) -> list:\n \"\"\"\n Alias for the predict method, enabling the model instance to be callable for predictions.\n\n This method simplifies the process of making predictions by allowing the model instance to be called\n directly with the required arguments.\n\n Args:\n source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of\n the image(s) to make predictions on. 
Can be a file path, URL, PIL image, numpy array, PyTorch\n tensor, or a list/tuple of these.\n stream (bool): If True, treat the input source as a continuous stream for predictions.\n **kwargs (Any): Additional keyword arguments to configure the prediction process.\n\n Returns:\n (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a\n Results object.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model(\"https://ultralytics.com/images/bus.jpg\")\n >>> for r in results:\n ... print(f\"Detected {len(r)} objects in image\")\n \"\"\"\n return self.predict(source, stream, **kwargs)\n\n @staticmethod\n def is_triton_model(model: str) -> bool:\n \"\"\"\n Check if the given model string is a Triton Server URL.\n\n This static method determines whether the provided model string represents a valid Triton Server URL by\n parsing its components using urllib.parse.urlsplit().\n\n Args:\n model (str): The model string to be checked.\n\n Returns:\n (bool): True if the model string is a valid Triton Server URL, False otherwise.\n\n Examples:\n >>> Model.is_triton_model(\"http://localhost:8000/v2/models/yolo11n\")\n True\n >>> Model.is_triton_model(\"yolo11n.pt\")\n False\n \"\"\"\n from urllib.parse import urlsplit\n\n url = urlsplit(model)\n return url.netloc and url.path and url.scheme in {\"http\", \"grpc\"}\n\n @staticmethod\n def is_hub_model(model: str) -> bool:\n \"\"\"\n Check if the provided model is an Ultralytics HUB model.\n\n This static method determines whether the given model string represents a valid Ultralytics HUB model\n identifier.\n\n Args:\n model (str): The model string to check.\n\n Returns:\n (bool): True if the model is a valid Ultralytics HUB model, False otherwise.\n\n Examples:\n >>> Model.is_hub_model(\"https://hub.ultralytics.com/models/MODEL\")\n True\n >>> Model.is_hub_model(\"yolo11n.pt\")\n False\n \"\"\"\n from ultralytics.hub import HUB_WEB_ROOT\n\n return model.startswith(f\"{HUB_WEB_ROOT}/models/\")\n\n def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:\n \"\"\"\n Initialize a new model and infer the task type from model definitions.\n\n Creates a new model instance based on the provided configuration file. Loads the model configuration, infers\n the task type if not specified, and initializes the model using the appropriate class from the task map.\n\n Args:\n cfg (str): Path to the model configuration file in YAML format.\n task (str, optional): The specific task for the model. If None, it will be inferred from the config.\n model (torch.nn.Module, optional): A custom model instance. 
If provided, it will be used instead of\n creating a new one.\n verbose (bool): If True, displays model information during loading.\n\n Raises:\n ValueError: If the configuration file is invalid or the task cannot be inferred.\n ImportError: If the required dependencies for the specified task are not installed.\n\n Examples:\n >>> model = Model()\n >>> model._new(\"yolo11n.yaml\", task=\"detect\", verbose=True)\n \"\"\"\n cfg_dict = yaml_model_load(cfg)\n self.cfg = cfg\n self.task = task or guess_model_task(cfg_dict)\n self.model = (model or self._smart_load(\"model\"))(cfg_dict, verbose=verbose and RANK == -1) # build model\n self.overrides[\"model\"] = self.cfg\n self.overrides[\"task\"] = self.task\n\n # Below added to allow export from YAMLs\n self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)\n self.model.task = self.task\n self.model_name = cfg\n\n def _load(self, weights: str, task=None) -> None:\n \"\"\"\n Load a model from a checkpoint file or initialize it from a weights file.\n\n This method handles loading models from either .pt checkpoint files or other weight file formats. It sets\n up the model, task, and related attributes based on the loaded weights.\n\n Args:\n weights (str): Path to the model weights file to be loaded.\n task (str, optional): The task associated with the model. If None, it will be inferred from the model.\n\n Raises:\n FileNotFoundError: If the specified weights file does not exist or is inaccessible.\n ValueError: If the weights file format is unsupported or invalid.\n\n Examples:\n >>> model = Model()\n >>> model._load(\"yolo11n.pt\")\n >>> model._load(\"path/to/weights.pth\", task=\"detect\")\n \"\"\"\n if weights.lower().startswith((\"https://\", \"http://\", \"rtsp://\", \"rtmp://\", \"tcp://\")):\n weights = checks.check_file(weights, download_dir=SETTINGS[\"weights_dir\"]) # download and return local file\n weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolo11n -> yolo11n.pt\n\n if str(weights).rpartition(\".\")[-1] == \"pt\":\n self.model, self.ckpt = attempt_load_one_weight(weights)\n self.task = self.model.task\n self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)\n self.ckpt_path = self.model.pt_path\n else:\n weights = checks.check_file(weights) # runs in all cases, not redundant with above call\n self.model, self.ckpt = weights, None\n self.task = task or guess_model_task(weights)\n self.ckpt_path = weights\n self.overrides[\"model\"] = weights\n self.overrides[\"task\"] = self.task\n self.model_name = weights\n\n def _check_is_pytorch_model(self) -> None:\n \"\"\"\n Check if the model is a PyTorch model and raise TypeError if it's not.\n\n This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that\n certain operations that require a PyTorch model are only performed on compatible model types.\n\n Raises:\n TypeError: If the model is not a PyTorch module or a .pt file. 
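# Small sketch (hypothetical filenames) of the suffix test Model._load above uses to decide
# whether the weights argument is a *.pt checkpoint or some other exported format.
for weights in ("yolo11n.pt", "runs/train/weights/best.pt", "yolo11n.onnx", "model.engine"):
    is_pt = str(weights).rpartition(".")[-1] == "pt"
    print(weights, "->", "load checkpoint" if is_pt else "keep path, infer task from name")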
The error message provides detailed\n information about supported model formats and operations.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model._check_is_pytorch_model() # No error raised\n >>> model = Model(\"yolo11n.onnx\")\n >>> model._check_is_pytorch_model() # Raises TypeError\n \"\"\"\n pt_str = isinstance(self.model, (str, Path)) and str(self.model).rpartition(\".\")[-1] == \"pt\"\n pt_module = isinstance(self.model, torch.nn.Module)\n if not (pt_module or pt_str):\n raise TypeError(\n f\"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. \"\n f\"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported \"\n f\"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, \"\n f\"i.e. 'yolo predict model=yolo11n.onnx'.\\nTo run CUDA or MPS inference please pass the device \"\n f\"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'\"\n )\n\n def reset_weights(self) -> \"Model\":\n \"\"\"\n Reset the model's weights to their initial state.\n\n This method iterates through all modules in the model and resets their parameters if they have a\n 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,\n enabling them to be updated during training.\n\n Returns:\n (Model): The instance of the class with reset weights.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model.reset_weights()\n \"\"\"\n self._check_is_pytorch_model()\n for m in self.model.modules():\n if hasattr(m, \"reset_parameters\"):\n m.reset_parameters()\n for p in self.model.parameters():\n p.requires_grad = True\n return self\n\n def load(self, weights: Union[str, Path] = \"yolo11n.pt\") -> \"Model\":\n \"\"\"\n Load parameters from the specified weights file into the model.\n\n This method supports loading weights from a file or directly from a weights object. It matches parameters by\n name and shape and transfers them to the model.\n\n Args:\n weights (str | Path): Path to the weights file or a weights object.\n\n Returns:\n (Model): The instance of the class with loaded weights.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = Model()\n >>> model.load(\"yolo11n.pt\")\n >>> model.load(Path(\"path/to/weights.pt\"))\n \"\"\"\n self._check_is_pytorch_model()\n if isinstance(weights, (str, Path)):\n self.overrides[\"pretrained\"] = weights # remember the weights for DDP training\n weights, self.ckpt = attempt_load_one_weight(weights)\n self.model.load(weights)\n return self\n\n def save(self, filename: Union[str, Path] = \"saved_model.pt\") -> None:\n \"\"\"\n Save the current model state to a file.\n\n This method exports the model's checkpoint (ckpt) to the specified filename. 
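# Illustrative sketch of the weight-reset loop in Model.reset_weights above, applied to a plain
# torch module as a stand-in for self.model.
import torch

m = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1))
for mod in m.modules():
    if hasattr(mod, "reset_parameters"):
        mod.reset_parameters()          # re-initialize layers that support it
for p in m.parameters():
    p.requires_grad = True              # ensure every parameter is trainable again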
It includes metadata such as\n the date, Ultralytics version, license information, and a link to the documentation.\n\n Args:\n filename (str | Path): The name of the file to save the model to.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model.save(\"my_model.pt\")\n \"\"\"\n self._check_is_pytorch_model()\n from copy import deepcopy\n from datetime import datetime\n\n from ultralytics import __version__\n\n updates = {\n \"model\": deepcopy(self.model).half() if isinstance(self.model, torch.nn.Module) else self.model,\n \"date\": datetime.now().isoformat(),\n \"version\": __version__,\n \"license\": \"AGPL-3.0 License (https://ultralytics.com/license)\",\n \"docs\": \"https://docs.ultralytics.com\",\n }\n torch.save({**self.ckpt, **updates}, filename)\n\n def info(self, detailed: bool = False, verbose: bool = True):\n \"\"\"\n Display model information.\n\n This method provides an overview or detailed information about the model, depending on the arguments\n passed. It can control the verbosity of the output and return the information as a list.\n\n Args:\n detailed (bool): If True, shows detailed information about the model layers and parameters.\n verbose (bool): If True, prints the information. If False, returns the information as a list.\n\n Returns:\n (List[str]): A list of strings containing various types of information about the model, including\n model summary, layer details, and parameter counts. Empty if verbose is True.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model.info() # Prints model summary\n >>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list\n \"\"\"\n self._check_is_pytorch_model()\n return self.model.info(detailed=detailed, verbose=verbose)\n\n def fuse(self) -> None:\n \"\"\"\n Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.\n\n This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers\n into a single layer. This fusion can significantly improve inference speed by reducing the number of\n operations and memory accesses required during forward passes.\n\n The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and\n bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that\n performs both convolution and normalization in one step.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model.fuse()\n >>> # Model is now fused and ready for optimized inference\n \"\"\"\n self._check_is_pytorch_model()\n self.model.fuse()\n\n def embed(\n self,\n source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,\n stream: bool = False,\n **kwargs: Any,\n ) -> list:\n \"\"\"\n Generate image embeddings based on the provided source.\n\n This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image\n source. It allows customization of the embedding process through various keyword arguments.\n\n Args:\n source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for\n generating embeddings. 
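# Sketch (hypothetical checkpoint dict and version string) of how Model.save above merges the
# loaded checkpoint with fresh metadata before torch.save writes it back out.
from datetime import datetime

ckpt = {"epoch": 100, "model": None}     # stand-in for a loaded *.pt checkpoint dict
updates = {"date": datetime.now().isoformat(), "version": "8.x.x", "docs": "https://docs.ultralytics.com"}
merged = {**ckpt, **updates}             # right-hand dict wins on duplicate keys
print(sorted(merged))                    # ['date', 'docs', 'epoch', 'model', 'version']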
Can be a file path, URL, PIL image, numpy array, etc.\n stream (bool): If True, predictions are streamed.\n **kwargs (Any): Additional keyword arguments for configuring the embedding process.\n\n Returns:\n (List[torch.Tensor]): A list containing the image embeddings.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> image = \"https://ultralytics.com/images/bus.jpg\"\n >>> embeddings = model.embed(image)\n >>> print(embeddings[0].shape)\n \"\"\"\n if not kwargs.get(\"embed\"):\n kwargs[\"embed\"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed\n return self.predict(source, stream, **kwargs)\n\n def predict(\n self,\n source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,\n stream: bool = False,\n predictor=None,\n **kwargs: Any,\n ) -> List[Results]:\n \"\"\"\n Perform predictions on the given image source using the YOLO model.\n\n This method facilitates the prediction process, allowing various configurations through keyword arguments.\n It supports predictions with custom predictors or the default predictor method. The method handles different\n types of image sources and can operate in a streaming mode.\n\n Args:\n source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source\n of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL\n images, numpy arrays, and torch tensors.\n stream (bool): If True, treats the input source as a continuous stream for predictions.\n predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.\n If None, the method uses a default predictor.\n **kwargs (Any): Additional keyword arguments for configuring the prediction process.\n\n Returns:\n (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a\n Results object.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.predict(source=\"path/to/image.jpg\", conf=0.25)\n >>> for r in results:\n ... print(r.boxes.data) # print detection bounding boxes\n\n Notes:\n - If 'source' is not provided, it defaults to the ASSETS constant with a warning.\n - The method sets up a new predictor if not already present and updates its arguments with each call.\n - For SAM-type models, 'prompts' can be passed as a keyword argument.\n \"\"\"\n if source is None:\n source = \"https://ultralytics.com/images/boats.jpg\" if self.task == \"obb\" else ASSETS\n LOGGER.warning(f\"'source' is missing. 
Using 'source={source}'.\")\n\n is_cli = (ARGV[0].endswith(\"yolo\") or ARGV[0].endswith(\"ultralytics\")) and any(\n x in ARGV for x in (\"predict\", \"track\", \"mode=predict\", \"mode=track\")\n )\n\n custom = {\"conf\": 0.25, \"batch\": 1, \"save\": is_cli, \"mode\": \"predict\", \"rect\": True} # method defaults\n args = {**self.overrides, **custom, **kwargs} # highest priority args on the right\n prompts = args.pop(\"prompts\", None) # for SAM-type models\n\n if not self.predictor:\n self.predictor = (predictor or self._smart_load(\"predictor\"))(overrides=args, _callbacks=self.callbacks)\n self.predictor.setup_model(model=self.model, verbose=is_cli)\n else: # only update args if predictor is already setup\n self.predictor.args = get_cfg(self.predictor.args, args)\n if \"project\" in args or \"name\" in args:\n self.predictor.save_dir = get_save_dir(self.predictor.args)\n if prompts and hasattr(self.predictor, \"set_prompts\"): # for SAM-type models\n self.predictor.set_prompts(prompts)\n return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)\n\n def track(\n self,\n source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,\n stream: bool = False,\n persist: bool = False,\n **kwargs: Any,\n ) -> List[Results]:\n \"\"\"\n Conduct object tracking on the specified input source using the registered trackers.\n\n This method performs object tracking using the model's predictors and optionally registered trackers. It handles\n various input sources such as file paths or video streams, and supports customization through keyword arguments.\n The method registers trackers if not already present and can persist them between calls.\n\n Args:\n source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor, optional): Input source for object\n tracking. Can be a file path, URL, or video stream.\n stream (bool): If True, treats the input source as a continuous video stream.\n persist (bool): If True, persists trackers between different calls to this method.\n **kwargs (Any): Additional keyword arguments for configuring the tracking process.\n\n Returns:\n (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.track(source=\"path/to/video.mp4\", show=True)\n >>> for r in results:\n ... print(r.boxes.id) # print tracking IDs\n\n Notes:\n - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.\n - The tracking mode is explicitly set in the keyword arguments.\n - Batch size is set to 1 for tracking in videos.\n \"\"\"\n if not hasattr(self.predictor, \"trackers\"):\n from ultralytics.trackers import register_tracker\n\n register_tracker(self, persist)\n kwargs[\"conf\"] = kwargs.get(\"conf\") or 0.1 # ByteTrack-based method needs low confidence predictions as input\n kwargs[\"batch\"] = kwargs.get(\"batch\") or 1 # batch-size 1 for tracking in videos\n kwargs[\"mode\"] = \"track\"\n return self.predict(source=source, stream=stream, **kwargs)\n\n def val(\n self,\n validator=None,\n **kwargs: Any,\n ):\n \"\"\"\n Validate the model using a specified dataset and validation configuration.\n\n This method facilitates the model validation process, allowing for customization through various settings. It\n supports validation with a custom validator or the default validation approach. 
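# Sketch (made-up values) of the "highest priority args on the right" merge used by
# Model.predict above and by val/export/train below: user kwargs override the method defaults,
# which in turn override the stored model overrides.
overrides = {"imgsz": 640, "conf": 0.5}            # from the loaded model / previous calls
custom = {"conf": 0.25, "batch": 1, "rect": True}  # method defaults
kwargs = {"conf": 0.9}                             # user-supplied
args = {**overrides, **custom, **kwargs, "mode": "predict"}
print(args["conf"], args["batch"], args["mode"])   # 0.9 1 predict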
The method combines default\n configurations, method-specific defaults, and user-provided arguments to configure the validation process.\n\n Args:\n validator (ultralytics.engine.validator.BaseValidator, optional): An instance of a custom validator class\n for validating the model.\n **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.\n\n Returns:\n (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.val(data=\"coco8.yaml\", imgsz=640)\n >>> print(results.box.map) # Print mAP50-95\n \"\"\"\n custom = {\"rect\": True} # method defaults\n args = {**self.overrides, **custom, **kwargs, \"mode\": \"val\"} # highest priority args on the right\n\n validator = (validator or self._smart_load(\"validator\"))(args=args, _callbacks=self.callbacks)\n validator(model=self.model)\n self.metrics = validator.metrics\n return validator.metrics\n\n def benchmark(self, data=None, format=\"\", verbose=False, **kwargs: Any):\n \"\"\"\n Benchmark the model across various export formats to evaluate performance.\n\n This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.\n It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is\n configured using a combination of default configuration values, model-specific arguments, method-specific\n defaults, and any additional user-provided keyword arguments.\n\n Args:\n data (str): Path to the dataset for benchmarking.\n verbose (bool): Whether to print detailed benchmark information.\n format (str): Export format name for specific benchmarking.\n **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include:\n - imgsz (int | List[int]): Image size for benchmarking.\n - half (bool): Whether to use half-precision (FP16) mode.\n - int8 (bool): Whether to use int8 precision mode.\n - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').\n\n Returns:\n (dict): A dictionary containing the results of the benchmarking process, including metrics for\n different export formats.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.benchmark(data=\"coco8.yaml\", imgsz=640, half=True)\n >>> print(results)\n \"\"\"\n self._check_is_pytorch_model()\n from ultralytics.utils.benchmarks import benchmark\n\n from .exporter import export_formats\n\n custom = {\"verbose\": False} # method defaults\n args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, \"mode\": \"benchmark\"}\n fmts = export_formats()\n export_args = set(dict(zip(fmts[\"Argument\"], fmts[\"Arguments\"])).get(format, [])) - {\"batch\"}\n export_kwargs = {k: v for k, v in args.items() if k in export_args}\n return benchmark(\n model=self,\n data=data, # if no 'data' argument passed set data=None for default datasets\n imgsz=args[\"imgsz\"],\n device=args[\"device\"],\n verbose=verbose,\n format=format,\n **export_kwargs,\n )\n\n def export(\n self,\n **kwargs: Any,\n ) -> str:\n \"\"\"\n Export the model to a different format suitable for deployment.\n\n This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment\n purposes. 
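# Sketch (with a made-up formats table) of how Model.benchmark above narrows user kwargs down
# to the arguments accepted by the chosen export format via export_formats().
fmts = {
    "Argument": ["onnx", "engine"],
    "Arguments": [["imgsz", "half", "dynamic", "batch"], ["imgsz", "half", "workspace", "batch"]],
}
fmt, args = "onnx", {"imgsz": 640, "half": True, "workspace": 4, "device": 0}
export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(fmt, [])) - {"batch"}
export_kwargs = {k: v for k, v in args.items() if k in export_args}
print(export_kwargs)  # {'imgsz': 640, 'half': True}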
It uses the 'Exporter' class for the export process, combining model-specific overrides, method\n defaults, and any additional arguments provided.\n\n Args:\n **kwargs (Any): Arbitrary keyword arguments to customize the export process. These are combined with\n the model's overrides and method defaults. Common arguments include:\n format (str): Export format (e.g., 'onnx', 'engine', 'coreml').\n half (bool): Export model in half-precision.\n int8 (bool): Export model in int8 precision.\n device (str): Device to run the export on.\n workspace (int): Maximum memory workspace size for TensorRT engines.\n nms (bool): Add Non-Maximum Suppression (NMS) module to model.\n simplify (bool): Simplify ONNX model.\n\n Returns:\n (str): The path to the exported model file.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n ValueError: If an unsupported export format is specified.\n RuntimeError: If the export process fails due to errors.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.export(format=\"onnx\", dynamic=True, simplify=True)\n 'path/to/exported/model.onnx'\n \"\"\"\n self._check_is_pytorch_model()\n from .exporter import Exporter\n\n custom = {\n \"imgsz\": self.model.args[\"imgsz\"],\n \"batch\": 1,\n \"data\": None,\n \"device\": None, # reset to avoid multi-GPU errors\n \"verbose\": False,\n } # method defaults\n args = {**self.overrides, **custom, **kwargs, \"mode\": \"export\"} # highest priority args on the right\n return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)\n\n def train(\n self,\n trainer=None,\n **kwargs: Any,\n ):\n \"\"\"\n Train the model using the specified dataset and training configuration.\n\n This method facilitates model training with a range of customizable settings. It supports training with a\n custom trainer or the default training approach. The method handles scenarios such as resuming training\n from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.\n\n When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training\n arguments and warns if local arguments are provided. It checks for pip updates and combines default\n configurations, method-specific defaults, and user-provided arguments to configure the training process.\n\n Args:\n trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.\n **kwargs (Any): Arbitrary keyword arguments for training configuration. 
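A hedged sketch of the export() flow described above: export to ONNX, then reload the exported file for inference. The image URL is an illustrative assumption, and running it requires the ONNX export dependencies.

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
onnx_path = model.export(format="onnx", dynamic=True, simplify=True)  # returns the exported file path
onnx_model = YOLO(onnx_path)                                          # exported models can be loaded back
results = onnx_model("https://ultralytics.com/images/bus.jpg")        # illustrative source (assumption)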
Common options include:\n data (str): Path to dataset configuration file.\n epochs (int): Number of training epochs.\n batch (int): Batch size for training.\n imgsz (int): Input image size.\n device (str): Device to run training on (e.g., 'cuda', 'cpu').\n workers (int): Number of worker threads for data loading.\n optimizer (str): Optimizer to use for training.\n lr0 (float): Initial learning rate.\n patience (int): Epochs to wait for no observable improvement for early stopping of training.\n\n Returns:\n (Dict | None): Training metrics if available and training is successful; otherwise, None.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.train(data=\"coco8.yaml\", epochs=3)\n \"\"\"\n self._check_is_pytorch_model()\n if hasattr(self.session, \"model\") and self.session.model.id: # Ultralytics HUB session with loaded model\n if any(kwargs):\n LOGGER.warning(\"using HUB training arguments, ignoring local training arguments.\")\n kwargs = self.session.train_args # overwrite kwargs\n\n checks.check_pip_update_available()\n\n if isinstance(kwargs.get(\"pretrained\", None), (str, Path)):\n self.load(kwargs[\"pretrained\"]) # load pretrained weights if provided\n overrides = YAML.load(checks.check_yaml(kwargs[\"cfg\"])) if kwargs.get(\"cfg\") else self.overrides\n custom = {\n # NOTE: handle the case when 'cfg' includes 'data'.\n \"data\": overrides.get(\"data\") or DEFAULT_CFG_DICT[\"data\"] or TASK2DATA[self.task],\n \"model\": self.overrides[\"model\"],\n \"task\": self.task,\n } # method defaults\n args = {**overrides, **custom, **kwargs, \"mode\": \"train\"} # highest priority args on the right\n if args.get(\"resume\"):\n args[\"resume\"] = self.ckpt_path\n\n self.trainer = (trainer or self._smart_load(\"trainer\"))(overrides=args, _callbacks=self.callbacks)\n if not args.get(\"resume\"): # manually set model only if not resuming\n self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)\n self.model = self.trainer.model\n\n self.trainer.hub_session = self.session # attach optional HUB session\n self.trainer.train()\n # Update model and cfg after training\n if RANK in {-1, 0}:\n ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last\n self.model, self.ckpt = attempt_load_one_weight(ckpt)\n self.overrides = self.model.args\n self.metrics = getattr(self.trainer.validator, \"metrics\", None) # TODO: no metrics returned by DDP\n return self.metrics\n\n def tune(\n self,\n use_ray=False,\n iterations=10,\n *args: Any,\n **kwargs: Any,\n ):\n \"\"\"\n Conduct hyperparameter tuning for the model, with an option to use Ray Tune.\n\n This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.\n When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.\n Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and\n custom arguments to configure the tuning process.\n\n Args:\n use_ray (bool): Whether to use Ray Tune for hyperparameter tuning. If False, uses internal tuning method.\n iterations (int): Number of tuning iterations to perform.\n *args (Any): Additional positional arguments to pass to the tuner.\n **kwargs (Any): Additional keyword arguments for tuning configuration. 
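A minimal training sketch following the train() implementation above; attribute names such as trainer.best mirror that implementation, and the small coco8 dataset keeps the run illustrative.

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
results = model.train(data="coco8.yaml", epochs=3, imgsz=640)
if results is not None:               # None on non-main DDP ranks
    print(results.box.map)            # mAP50-95 from the post-training validation
print(model.trainer.best)             # path to the best checkpoint saved during training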
These are combined with model\n overrides and defaults to configure the tuning process.\n\n Returns:\n (dict): Results of the hyperparameter search, including best parameters and performance metrics.\n\n Raises:\n TypeError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.tune(data=\"coco8.yaml\", iterations=5)\n >>> print(results)\n\n # Use Ray Tune for more advanced hyperparameter search\n >>> results = model.tune(use_ray=True, iterations=20, data=\"coco8.yaml\")\n \"\"\"\n self._check_is_pytorch_model()\n if use_ray:\n from ultralytics.utils.tuner import run_ray_tune\n\n return run_ray_tune(self, max_samples=iterations, *args, **kwargs)\n else:\n from .tuner import Tuner\n\n custom = {} # method defaults\n args = {**self.overrides, **custom, **kwargs, \"mode\": \"train\"} # highest priority args on the right\n return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)\n\n def _apply(self, fn) -> \"Model\":\n \"\"\"\n Apply a function to model tensors that are not parameters or registered buffers.\n\n This method extends the functionality of the parent class's _apply method by additionally resetting the\n predictor and updating the device in the model's overrides. It's typically used for operations like\n moving the model to a different device or changing its precision.\n\n Args:\n fn (Callable): A function to be applied to the model's tensors. This is typically a method like\n to(), cpu(), cuda(), half(), or float().\n\n Returns:\n (Model): The model instance with the function applied and updated attributes.\n\n Raises:\n AssertionError: If the model is not a PyTorch model.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU\n \"\"\"\n self._check_is_pytorch_model()\n self = super()._apply(fn) # noqa\n self.predictor = None # reset predictor as device may have changed\n self.overrides[\"device\"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'\n return self\n\n @property\n def names(self) -> Dict[int, str]:\n \"\"\"\n Retrieve the class names associated with the loaded model.\n\n This property returns the class names if they are defined in the model. It checks the class names for validity\n using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not\n initialized, it sets it up before retrieving the names.\n\n Returns:\n (Dict[int, str]): A dictionary of class names associated with the model, where keys are class indices and\n values are the corresponding class names.\n\n Raises:\n AttributeError: If the model or predictor does not have a 'names' attribute.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> print(model.names)\n {0: 'person', 1: 'bicycle', 2: 'car', ...}\n \"\"\"\n from ultralytics.nn.autobackend import check_class_names\n\n if hasattr(self.model, \"names\"):\n return check_class_names(self.model.names)\n if not self.predictor: # export formats will not have predictor defined until predict() is called\n self.predictor = self._smart_load(\"predictor\")(overrides=self.overrides, _callbacks=self.callbacks)\n self.predictor.setup_model(model=self.model, verbose=False)\n return self.predictor.model.names\n\n @property\n def device(self) -> torch.device:\n \"\"\"\n Get the device on which the model's parameters are allocated.\n\n This property determines the device (CPU or GPU) where the model's parameters are currently stored. 
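A hedged sketch exercising the device and names properties described above; moving the model with .to() is routed through _apply(), which resets the predictor and updates overrides["device"].

import torch
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
print(model.device)                          # e.g. device(type='cpu')
print(dict(list(model.names.items())[:3]))   # {0: 'person', 1: 'bicycle', 2: 'car'}
if torch.cuda.is_available():
    model.to("cuda")                         # triggers _apply(): predictor reset, device override updated
    print(model.device)                      # device(type='cuda', index=0)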
It is\n applicable only to models that are instances of torch.nn.Module.\n\n Returns:\n (torch.device): The device (CPU/GPU) of the model.\n\n Raises:\n AttributeError: If the model is not a torch.nn.Module instance.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> print(model.device)\n device(type='cuda', index=0) # if CUDA is available\n >>> model = model.to(\"cpu\")\n >>> print(model.device)\n device(type='cpu')\n \"\"\"\n return next(self.model.parameters()).device if isinstance(self.model, torch.nn.Module) else None\n\n @property\n def transforms(self):\n \"\"\"\n Retrieve the transformations applied to the input data of the loaded model.\n\n This property returns the transformations if they are defined in the model. The transforms\n typically include preprocessing steps like resizing, normalization, and data augmentation\n that are applied to input data before it is fed into the model.\n\n Returns:\n (object | None): The transform object of the model if available, otherwise None.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> transforms = model.transforms\n >>> if transforms:\n ... print(f\"Model transforms: {transforms}\")\n ... else:\n ... print(\"No transforms defined for this model.\")\n \"\"\"\n return self.model.transforms if hasattr(self.model, \"transforms\") else None\n\n def add_callback(self, event: str, func) -> None:\n \"\"\"\n Add a callback function for a specified event.\n\n This method allows registering custom callback functions that are triggered on specific events during\n model operations such as training or inference. Callbacks provide a way to extend and customize the\n behavior of the model at various stages of its lifecycle.\n\n Args:\n event (str): The name of the event to attach the callback to. Must be a valid event name recognized\n by the Ultralytics framework.\n func (Callable): The callback function to be registered. This function will be called when the\n specified event occurs.\n\n Raises:\n ValueError: If the event name is not recognized or is invalid.\n\n Examples:\n >>> def on_train_start(trainer):\n ... print(\"Training is starting!\")\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.add_callback(\"on_train_start\", on_train_start)\n >>> model.train(data=\"coco8.yaml\", epochs=1)\n \"\"\"\n self.callbacks[event].append(func)\n\n def clear_callback(self, event: str) -> None:\n \"\"\"\n Clear all callback functions registered for a specified event.\n\n This method removes all custom and default callback functions associated with the given event.\n It resets the callback list for the specified event to an empty list, effectively removing all\n registered callbacks for that event.\n\n Args:\n event (str): The name of the event for which to clear the callbacks. 
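A short sketch of the callback API covered above, mirroring the add_callback() and clear_callback() docstring examples.

from ultralytics import YOLO

def on_train_start(trainer):
    print("Training is starting!")

model = YOLO("yolo11n.pt")
model.add_callback("on_train_start", on_train_start)
model.train(data="coco8.yaml", epochs=1)
model.clear_callback("on_train_start")   # removes every callback registered for this event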
This should be a valid event name\n recognized by the Ultralytics callback system.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.add_callback(\"on_train_start\", lambda: print(\"Training started\"))\n >>> model.clear_callback(\"on_train_start\")\n >>> # All callbacks for 'on_train_start' are now removed\n\n Notes:\n - This method affects both custom callbacks added by the user and default callbacks\n provided by the Ultralytics framework.\n - After calling this method, no callbacks will be executed for the specified event\n until new ones are added.\n - Use with caution as it removes all callbacks, including essential ones that might\n be required for proper functioning of certain operations.\n \"\"\"\n self.callbacks[event] = []\n\n def reset_callbacks(self) -> None:\n \"\"\"\n Reset all callbacks to their default functions.\n\n This method reinstates the default callback functions for all events, removing any custom callbacks that were\n previously added. It iterates through all default callback events and replaces the current callbacks with the\n default ones.\n\n The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined\n functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.\n\n This method is useful when you want to revert to the original set of callbacks after making custom\n modifications, ensuring consistent behavior across different runs or experiments.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.add_callback(\"on_train_start\", custom_function)\n >>> model.reset_callbacks()\n # All callbacks are now reset to their default functions\n \"\"\"\n for event in callbacks.default_callbacks.keys():\n self.callbacks[event] = [callbacks.default_callbacks[event][0]]\n\n @staticmethod\n def _reset_ckpt_args(args: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Reset specific arguments when loading a PyTorch model checkpoint.\n\n This method filters the input arguments dictionary to retain only a specific set of keys that are\n considered important for model loading. It's used to ensure that only relevant arguments are preserved\n when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.\n\n Args:\n args (dict): A dictionary containing various model arguments and settings.\n\n Returns:\n (dict): A new dictionary containing only the specified include keys from the input arguments.\n\n Examples:\n >>> original_args = {\"imgsz\": 640, \"data\": \"coco.yaml\", \"task\": \"detect\", \"batch\": 16, \"epochs\": 100}\n >>> reset_args = Model._reset_ckpt_args(original_args)\n >>> print(reset_args)\n {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}\n \"\"\"\n include = {\"imgsz\", \"data\", \"task\", \"single_cls\"} # only remember these arguments when loading a PyTorch model\n return {k: v for k, v in args.items() if k in include}\n\n # def __getattr__(self, attr):\n # \"\"\"Raises error if object has no requested attribute.\"\"\"\n # name = self.__class__.__name__\n # raise AttributeError(f\"'{name}' object has no attribute '{attr}'. See valid attributes below.\\n{self.__doc__}\")\n\n def _smart_load(self, key: str):\n \"\"\"\n Intelligently load the appropriate module based on the model task.\n\n This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)\n based on the current task of the model and the provided key. 
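A pure-Python sketch of the include-set filtering that _reset_ckpt_args() applies when a checkpoint is loaded; the argument values are hypothetical.

include = {"imgsz", "data", "task", "single_cls"}   # keys retained when loading a PyTorch checkpoint
ckpt_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100}
kept = {k: v for k, v in ckpt_args.items() if k in include}
assert kept == {"imgsz": 640, "data": "coco.yaml", "task": "detect"}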
It uses the task_map dictionary to determine\n the appropriate module to load for the specific task.\n\n Args:\n key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.\n\n Returns:\n (object): The loaded module class corresponding to the specified key and current task.\n\n Raises:\n NotImplementedError: If the specified key is not supported for the current task.\n\n Examples:\n >>> model = Model(task=\"detect\")\n >>> predictor_class = model._smart_load(\"predictor\")\n >>> trainer_class = model._smart_load(\"trainer\")\n \"\"\"\n try:\n return self.task_map[self.task][key]\n except Exception as e:\n name = self.__class__.__name__\n mode = inspect.stack()[1][3] # get the function name.\n raise NotImplementedError(f\"'{name}' model does not support '{mode}' mode for '{self.task}' task.\") from e\n\n @property\n def task_map(self) -> dict:\n \"\"\"\n Provide a mapping from model tasks to corresponding classes for different modes.\n\n This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)\n to a nested dictionary. The nested dictionary contains mappings for different operational modes\n (model, trainer, validator, predictor) to their respective class implementations.\n\n The mapping allows for dynamic loading of appropriate classes based on the model's task and the\n desired operational mode. This facilitates a flexible and extensible architecture for handling\n various tasks and modes within the Ultralytics framework.\n\n Returns:\n (Dict[str, Dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary\n contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class\n implementations for that task.\n\n Examples:\n >>> model = Model(\"yolo11n.pt\")\n >>> task_map = model.task_map\n >>> detect_predictor = task_map[\"detect\"][\"predictor\"]\n >>> segment_trainer = task_map[\"segment\"][\"trainer\"]\n \"\"\"\n raise NotImplementedError(\"Please provide task map for your model!\")\n\n def eval(self):\n \"\"\"\n Sets the model to evaluation mode.\n\n This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization\n that behave differently during training and evaluation. In evaluation mode, these layers use running statistics\n rather than computing batch statistics, and dropout layers are disabled.\n\n Returns:\n (Model): The model instance with evaluation mode set.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.eval()\n >>> # Model is now in evaluation mode for inference\n \"\"\"\n self.model.eval()\n return self\n\n def __getattr__(self, name):\n \"\"\"\n Enable accessing model attributes directly through the Model class.\n\n This method provides a way to access attributes of the underlying model directly through the Model class\n instance. It first checks if the requested attribute is 'model', in which case it returns the model from\n the module dictionary. 
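A hedged sketch of how a concrete subclass typically satisfies the abstract task_map property so _smart_load() can resolve classes per task; the import paths mirror ultralytics.models.yolo and are stated here as assumptions.

from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel

class MyDetector(Model):
    @property
    def task_map(self) -> dict:
        # one nested dict per supported task; _smart_load() indexes it as task_map[task][key]
        return {
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            }
        }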
Otherwise, it delegates the attribute lookup to the underlying model.\n\n Args:\n name (str): The name of the attribute to retrieve.\n\n Returns:\n (Any): The requested attribute value.\n\n Raises:\n AttributeError: If the requested attribute does not exist in the model.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> print(model.stride) # Access model.stride attribute\n >>> print(model.names) # Access model.names attribute\n \"\"\"\n return self._modules[\"model\"] if name == \"model\" else getattr(self.model, name)", "chunk_type": "class", "name": "Model", "file_path": "ultralytics\\ultralytics\\engine\\model.py", "start_line": 27, "end_line": 1162, "start_col": 0, "end_col": 87, "parent_name": null, "docstring": "A base class for implementing YOLO models, unifying APIs across different model types.\n\nThis class provides a common interface for various operations related to YOLO models, such as training,\nvalidation, prediction, exporting, and benchmarking. It handles different types of models, including those\nloaded from local files, Ultralytics HUB, or Triton Server.\n\nAttributes:\n callbacks (dict): A dictionary of callback functions for various events during model operations.\n predictor (BasePredictor): The predictor object used for making predictions.\n model (torch.nn.Module): The underlying PyTorch model.\n trainer (BaseTrainer): The trainer object used for training the model.\n ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.\n cfg (str): The configuration of the model if loaded from a *.yaml file.\n ckpt_path (str): The path to the checkpoint file.\n overrides (dict): A dictionary of overrides for model configuration.\n metrics (dict): The latest training/validation metrics.\n session (HUBTrainingSession): The Ultralytics HUB session, if applicable.\n task (str): The type of task the model is intended for.\n model_name (str): The name of the model.\n\nMethods:\n __call__: Alias for the predict method, enabling the model instance to be callable.\n _new: Initialize a new model based on a configuration file.\n _load: Load a model from a checkpoint file.\n _check_is_pytorch_model: Ensure that the model is a PyTorch model.\n reset_weights: Reset the model's weights to their initial state.\n load: Load model weights from a specified file.\n save: Save the current state of the model to a file.\n info: Log or return information about the model.\n fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.\n predict: Perform object detection predictions.\n track: Perform object tracking.\n val: Validate the model on a dataset.\n benchmark: Benchmark the model on various export formats.\n export: Export the model to different formats.\n train: Train the model on a dataset.\n tune: Perform hyperparameter tuning.\n _apply: Apply a function to the model's tensors.\n add_callback: Add a callback function for an event.\n clear_callback: Clear all callbacks for an event.\n reset_callbacks: Reset all callbacks to their default functions.\n\nExamples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model.predict(\"image.jpg\")\n >>> model.train(data=\"coco8.yaml\", epochs=3)\n >>> metrics = model.val()\n >>> model.export(format=\"onnx\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "inspect", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Union", "numpy", "torch", "PIL.Image", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.get_cfg", 
"ultralytics.cfg.get_save_dir", "ultralytics.engine.results.Results", "ultralytics.nn.tasks.attempt_load_one_weight", "ultralytics.nn.tasks.guess_model_task", "ultralytics.nn.tasks.yaml_model_load", "ultralytics.utils.ARGV", "ultralytics.utils.ASSETS", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.checks", "urllib.parse.urlsplit", "ultralytics.hub.HUB_WEB_ROOT", "copy.deepcopy", "datetime.datetime", "ultralytics.__version__", "ultralytics.utils.benchmarks.benchmark", "exporter.export_formats", "exporter.Exporter", "ultralytics.nn.autobackend.check_class_names", "ultralytics.hub.HUBTrainingSession", "ultralytics.trackers.register_tracker", "ultralytics.utils.tuner.run_ray_tune", "tuner.Tuner", "torch.nn.Module" ], "chunk_id": "class_Model_f6837c95" }, { "content": "import platform", "chunk_type": "import", "name": "platform", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 35, "end_line": 35, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_platform_f67c81a1" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_2565edf8" }, { "content": "import threading", "chunk_type": "import", "name": "threading", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_threading_637bbd71" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 38, "end_line": 38, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_872d6f14" }, { "content": "from typing import Any, Dict, List, Optional, Union", "chunk_type": "import", "name": "Any, Dict, List, Optional, Union", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Union_5f1e25a9" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 41, "end_line": 41, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_8685e6b8" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 42, "end_line": 42, "start_col": 0, "end_col": 18, 
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_4aa493bf" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 43, "end_line": 43, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_d783ad48" }, { "content": "from ultralytics.cfg import get_cfg, get_save_dir", "chunk_type": "import", "name": "get_cfg, get_save_dir", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 45, "end_line": 45, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cfg, get_save_dir_fee1085e" }, { "content": "from ultralytics.data import load_inference_source", "chunk_type": "import", "name": "load_inference_source", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 46, "end_line": 46, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_load_inference_source_e15aadd4" }, { "content": "from ultralytics.data.augment import LetterBox", "chunk_type": "import", "name": "LetterBox", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 47, "end_line": 47, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LetterBox_9abf80df" }, { "content": "from ultralytics.nn.autobackend import AutoBackend", "chunk_type": "import", "name": "AutoBackend", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 48, "end_line": 48, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AutoBackend_aed46707" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 49, "end_line": 49, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops_28f23817" }, { "content": "from ultralytics.utils.checks import check_imgsz, check_imshow", "chunk_type": "import", "name": "check_imgsz, check_imshow", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 50, "end_line": 50, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_imgsz, check_imshow_dc1e3026" }, { "content": "from ultralytics.utils.files import increment_path", "chunk_type": "import", "name": "increment_path", "file_path": 
"ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 51, "end_line": 51, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_increment_path_d663f880" }, { "content": "from ultralytics.utils.torch_utils import select_device, smart_inference_mode", "chunk_type": "import", "name": "select_device, smart_inference_mode", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 52, "end_line": 52, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_select_device, smart_inference_mode_c32beaa2" }, { "content": "STREAM_WARNING = \"\"\"\ninference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory\nerrors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.\n\nExample:\n results = model(source=..., stream=True) # generator of Results objects\n for r in results:\n boxes = r.boxes # Boxes object for bbox outputs\n masks = r.masks # Masks object for segment masks outputs\n probs = r.probs # Class probabilities for classification outputs\n\"\"\"", "chunk_type": "variable", "name": "STREAM_WARNING", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 54, "end_line": 64, "start_col": 0, "end_col": 3, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_STREAM_WARNING_5437081c" }, { "content": "class BasePredictor:\n \"\"\"\n A base class for creating predictors.\n\n This class provides the foundation for prediction functionality, handling model setup, inference,\n and result processing across various input sources.\n\n Attributes:\n args (SimpleNamespace): Configuration for the predictor.\n save_dir (Path): Directory to save results.\n done_warmup (bool): Whether the predictor has finished setup.\n model (torch.nn.Module): Model used for prediction.\n data (dict): Data configuration.\n device (torch.device): Device used for prediction.\n dataset (Dataset): Dataset used for prediction.\n vid_writer (Dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.\n plotted_img (np.ndarray): Last plotted image.\n source_type (SimpleNamespace): Type of input source.\n seen (int): Number of images processed.\n windows (List[str]): List of window names for visualization.\n batch (tuple): Current batch data.\n results (List[Any]): Current batch results.\n transforms (callable): Image transforms for classification.\n callbacks (Dict[str, List[callable]]): Callback functions for different events.\n txt_path (Path): Path to save text results.\n _lock (threading.Lock): Lock for thread-safe inference.\n\n Methods:\n preprocess: Prepare input image before inference.\n inference: Run inference on a given image.\n postprocess: Process raw predictions into structured results.\n predict_cli: Run prediction for command line interface.\n setup_source: Set up input source and inference mode.\n stream_inference: Stream inference on input source.\n setup_model: Initialize and configure the model.\n write_results: Write inference results to files.\n save_predicted_images: Save prediction visualizations.\n show: Display results in a 
window.\n run_callbacks: Execute registered callbacks for an event.\n add_callback: Register a new callback function.\n \"\"\"\n\n def __init__(\n self,\n cfg=DEFAULT_CFG,\n overrides: Optional[Dict[str, Any]] = None,\n _callbacks: Optional[Dict[str, List[callable]]] = None,\n ):\n \"\"\"\n Initialize the BasePredictor class.\n\n Args:\n cfg (str | dict): Path to a configuration file or a configuration dictionary.\n overrides (dict, optional): Configuration overrides.\n _callbacks (dict, optional): Dictionary of callback functions.\n \"\"\"\n self.args = get_cfg(cfg, overrides)\n self.save_dir = get_save_dir(self.args)\n if self.args.conf is None:\n self.args.conf = 0.25 # default conf=0.25\n self.done_warmup = False\n if self.args.show:\n self.args.show = check_imshow(warn=True)\n\n # Usable if setup is done\n self.model = None\n self.data = self.args.data # data_dict\n self.imgsz = None\n self.device = None\n self.dataset = None\n self.vid_writer = {} # dict of {save_path: video_writer, ...}\n self.plotted_img = None\n self.source_type = None\n self.seen = 0\n self.windows = []\n self.batch = None\n self.results = None\n self.transforms = None\n self.callbacks = _callbacks or callbacks.get_default_callbacks()\n self.txt_path = None\n self._lock = threading.Lock() # for automatic thread-safe inference\n callbacks.add_integration_callbacks(self)\n\n def preprocess(self, im: Union[torch.Tensor, List[np.ndarray]]) -> torch.Tensor:\n \"\"\"\n Prepare input image before inference.\n\n Args:\n im (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.\n\n Returns:\n (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).\n \"\"\"\n not_tensor = not isinstance(im, torch.Tensor)\n if not_tensor:\n im = np.stack(self.pre_transform(im))\n if im.shape[-1] == 3:\n im = im[..., ::-1] # BGR to RGB\n im = im.transpose((0, 3, 1, 2)) # BHWC to BCHW, (n, 3, h, w)\n im = np.ascontiguousarray(im) # contiguous\n im = torch.from_numpy(im)\n\n im = im.to(self.device)\n im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32\n if not_tensor:\n im /= 255 # 0 - 255 to 0.0 - 1.0\n return im\n\n def inference(self, im: torch.Tensor, *args, **kwargs):\n \"\"\"Run inference on a given image using the specified model and arguments.\"\"\"\n visualize = (\n increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)\n if self.args.visualize and (not self.source_type.tensor)\n else False\n )\n return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)\n\n def pre_transform(self, im: List[np.ndarray]) -> List[np.ndarray]:\n \"\"\"\n Pre-transform input image before inference.\n\n Args:\n im (List[np.ndarray]): List of images with shape [(H, W, 3) x N].\n\n Returns:\n (List[np.ndarray]): List of transformed images.\n \"\"\"\n same_shapes = len({x.shape for x in im}) == 1\n letterbox = LetterBox(\n self.imgsz,\n auto=same_shapes\n and self.args.rect\n and (self.model.pt or (getattr(self.model, \"dynamic\", False) and not self.model.imx)),\n stride=self.model.stride,\n )\n return [letterbox(image=x) for x in im]\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"Post-process predictions for an image and return them.\"\"\"\n return preds\n\n def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):\n \"\"\"\n Perform inference on an image or stream.\n\n Args:\n source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):\n 
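A standalone PyTorch/NumPy sketch of the tensor preparation done in preprocess() above (stack, BGR to RGB, BHWC to BCHW, contiguous copy, scale to [0, 1]); the input shapes are hypothetical letterboxed frames.

import numpy as np
import torch

ims = [np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) for _ in range(2)]  # letterboxed BGR frames
im = np.stack(ims)                       # (N, H, W, 3)
im = im[..., ::-1]                       # BGR -> RGB
im = im.transpose((0, 3, 1, 2))          # BHWC -> BCHW
im = np.ascontiguousarray(im)            # make memory contiguous before torch.from_numpy
t = torch.from_numpy(im).float() / 255   # uint8 [0, 255] -> float32 [0.0, 1.0]
print(t.shape, t.dtype)                  # torch.Size([2, 3, 640, 640]) torch.float32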
Source for inference.\n model (str | Path | torch.nn.Module, optional): Model for inference.\n stream (bool): Whether to stream the inference results. If True, returns a generator.\n *args (Any): Additional arguments for the inference method.\n **kwargs (Any): Additional keyword arguments for the inference method.\n\n Returns:\n (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.\n \"\"\"\n self.stream = stream\n if stream:\n return self.stream_inference(source, model, *args, **kwargs)\n else:\n return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one\n\n def predict_cli(self, source=None, model=None):\n \"\"\"\n Method used for Command Line Interface (CLI) prediction.\n\n This function is designed to run predictions using the CLI. It sets up the source and model, then processes\n the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the\n generator without storing results.\n\n Args:\n source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):\n Source for inference.\n model (str | Path | torch.nn.Module, optional): Model for inference.\n\n Note:\n Do not modify this function or remove the generator. The generator ensures that no outputs are\n accumulated in memory, which is critical for preventing memory issues during long-running predictions.\n \"\"\"\n gen = self.stream_inference(source, model)\n for _ in gen: # sourcery skip: remove-empty-nested-block, noqa\n pass\n\n def setup_source(self, source):\n \"\"\"\n Set up source and inference mode.\n\n Args:\n source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):\n Source for inference.\n \"\"\"\n self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size\n self.dataset = load_inference_source(\n source=source,\n batch=self.args.batch,\n vid_stride=self.args.vid_stride,\n buffer=self.args.stream_buffer,\n channels=getattr(self.model, \"ch\", 3),\n )\n self.source_type = self.dataset.source_type\n if not getattr(self, \"stream\", True) and (\n self.source_type.stream\n or self.source_type.screenshot\n or len(self.dataset) > 1000 # many images\n or any(getattr(self.dataset, \"video_flag\", [False]))\n ): # videos\n LOGGER.warning(STREAM_WARNING)\n self.vid_writer = {}\n\n @smart_inference_mode()\n def stream_inference(self, source=None, model=None, *args, **kwargs):\n \"\"\"\n Stream real-time inference on camera feed and save results to file.\n\n Args:\n source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):\n Source for inference.\n model (str | Path | torch.nn.Module, optional): Model for inference.\n *args (Any): Additional arguments for the inference method.\n **kwargs (Any): Additional keyword arguments for the inference method.\n\n Yields:\n (ultralytics.engine.results.Results): Results objects.\n \"\"\"\n if self.args.verbose:\n LOGGER.info(\"\")\n\n # Setup model\n if not self.model:\n self.setup_model(model)\n\n with self._lock: # for thread-safe inference\n # Setup source every time predict is called\n self.setup_source(source if source is not None else self.args.source)\n\n # Check if save_dir/ label file exists\n if self.args.save or self.args.save_txt:\n (self.save_dir / \"labels\" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)\n\n # Warmup model\n if not self.done_warmup:\n 
self.model.warmup(\n imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)\n )\n self.done_warmup = True\n\n self.seen, self.windows, self.batch = 0, [], None\n profilers = (\n ops.Profile(device=self.device),\n ops.Profile(device=self.device),\n ops.Profile(device=self.device),\n )\n self.run_callbacks(\"on_predict_start\")\n for self.batch in self.dataset:\n self.run_callbacks(\"on_predict_batch_start\")\n paths, im0s, s = self.batch\n\n # Preprocess\n with profilers[0]:\n im = self.preprocess(im0s)\n\n # Inference\n with profilers[1]:\n preds = self.inference(im, *args, **kwargs)\n if self.args.embed:\n yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors\n continue\n\n # Postprocess\n with profilers[2]:\n self.results = self.postprocess(preds, im, im0s)\n self.run_callbacks(\"on_predict_postprocess_end\")\n\n # Visualize, save, write results\n n = len(im0s)\n try:\n for i in range(n):\n self.seen += 1\n self.results[i].speed = {\n \"preprocess\": profilers[0].dt * 1e3 / n,\n \"inference\": profilers[1].dt * 1e3 / n,\n \"postprocess\": profilers[2].dt * 1e3 / n,\n }\n if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:\n s[i] += self.write_results(i, Path(paths[i]), im, s)\n except StopIteration:\n break\n\n # Print batch results\n if self.args.verbose:\n LOGGER.info(\"\\n\".join(s))\n\n self.run_callbacks(\"on_predict_batch_end\")\n yield from self.results\n\n # Release assets\n for v in self.vid_writer.values():\n if isinstance(v, cv2.VideoWriter):\n v.release()\n\n if self.args.show:\n cv2.destroyAllWindows() # close any open windows\n\n # Print final results\n if self.args.verbose and self.seen:\n t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image\n LOGGER.info(\n f\"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape \"\n f\"{(min(self.args.batch, self.seen), getattr(self.model, 'ch', 3), *im.shape[2:])}\" % t\n )\n if self.args.save or self.args.save_txt or self.args.save_crop:\n nl = len(list(self.save_dir.glob(\"labels/*.txt\"))) # number of labels\n s = f\"\\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}\" if self.args.save_txt else \"\"\n LOGGER.info(f\"Results saved to {colorstr('bold', self.save_dir)}{s}\")\n self.run_callbacks(\"on_predict_end\")\n\n def setup_model(self, model, verbose: bool = True):\n \"\"\"\n Initialize YOLO model with given parameters and set it to evaluation mode.\n\n Args:\n model (str | Path | torch.nn.Module, optional): Model to load or use.\n verbose (bool): Whether to print verbose output.\n \"\"\"\n self.model = AutoBackend(\n weights=model or self.args.model,\n device=select_device(self.args.device, verbose=verbose),\n dnn=self.args.dnn,\n data=self.args.data,\n fp16=self.args.half,\n batch=self.args.batch,\n fuse=True,\n verbose=verbose,\n )\n\n self.device = self.model.device # update device\n self.args.half = self.model.fp16 # update half\n if hasattr(self.model, \"imgsz\") and not getattr(self.model, \"dynamic\", False):\n self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata\n self.model.eval()\n\n def write_results(self, i: int, p: Path, im: torch.Tensor, s: List[str]) -> str:\n \"\"\"\n Write inference results to a file or directory.\n\n Args:\n i (int): Index of the current image in the batch.\n p (Path): Path to the current image.\n im (torch.Tensor): Preprocessed image tensor.\n s (List[str]): List of result strings.\n\n Returns:\n (str): String 
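A sketch of the per-image speed bookkeeping used in stream_inference() above: each Profile context records its elapsed time in .dt (seconds), and per-image milliseconds are dt * 1e3 / batch size. Calling Profile() without a device argument is assumed to use its default.

from ultralytics.utils import ops

n = 4                                              # images in the current batch (hypothetical)
profilers = (ops.Profile(), ops.Profile(), ops.Profile())
with profilers[0]:
    pass                                           # preprocess would run here
with profilers[1]:
    pass                                           # inference would run here
with profilers[2]:
    pass                                           # postprocess would run here
speed = {
    "preprocess": profilers[0].dt * 1e3 / n,       # milliseconds per image
    "inference": profilers[1].dt * 1e3 / n,
    "postprocess": profilers[2].dt * 1e3 / n,
}
print(speed)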
with result information.\n \"\"\"\n string = \"\" # print string\n if len(im.shape) == 3:\n im = im[None] # expand for batch dim\n if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1\n string += f\"{i}: \"\n frame = self.dataset.count\n else:\n match = re.search(r\"frame (\\d+)/\", s[i])\n frame = int(match[1]) if match else None # 0 if frame undetermined\n\n self.txt_path = self.save_dir / \"labels\" / (p.stem + (\"\" if self.dataset.mode == \"image\" else f\"_{frame}\"))\n string += \"{:g}x{:g} \".format(*im.shape[2:])\n result = self.results[i]\n result.save_dir = self.save_dir.__str__() # used in other locations\n string += f\"{result.verbose()}{result.speed['inference']:.1f}ms\"\n\n # Add predictions to image\n if self.args.save or self.args.show:\n self.plotted_img = result.plot(\n line_width=self.args.line_width,\n boxes=self.args.show_boxes,\n conf=self.args.show_conf,\n labels=self.args.show_labels,\n im_gpu=None if self.args.retina_masks else im[i],\n )\n\n # Save results\n if self.args.save_txt:\n result.save_txt(f\"{self.txt_path}.txt\", save_conf=self.args.save_conf)\n if self.args.save_crop:\n result.save_crop(save_dir=self.save_dir / \"crops\", file_name=self.txt_path.stem)\n if self.args.show:\n self.show(str(p))\n if self.args.save:\n self.save_predicted_images(self.save_dir / p.name, frame)\n\n return string\n\n def save_predicted_images(self, save_path: Path, frame: int = 0):\n \"\"\"\n Save video predictions as mp4 or images as jpg at specified path.\n\n Args:\n save_path (Path): Path to save the results.\n frame (int): Frame number for video mode.\n \"\"\"\n im = self.plotted_img\n\n # Save videos and streams\n if self.dataset.mode in {\"stream\", \"video\"}:\n fps = self.dataset.fps if self.dataset.mode == \"video\" else 30\n frames_path = self.save_dir / f\"{save_path.stem}_frames\" # save frames to a separate directory\n if save_path not in self.vid_writer: # new video\n if self.args.save_frames:\n Path(frames_path).mkdir(parents=True, exist_ok=True)\n suffix, fourcc = (\".mp4\", \"avc1\") if MACOS else (\".avi\", \"WMV2\") if WINDOWS else (\".avi\", \"MJPG\")\n self.vid_writer[save_path] = cv2.VideoWriter(\n filename=str(Path(save_path).with_suffix(suffix)),\n fourcc=cv2.VideoWriter_fourcc(*fourcc),\n fps=fps, # integer required, floats produce error in MP4 codec\n frameSize=(im.shape[1], im.shape[0]), # (width, height)\n )\n\n # Save video\n self.vid_writer[save_path].write(im)\n if self.args.save_frames:\n cv2.imwrite(f\"{frames_path}/{save_path.stem}_{frame}.jpg\", im)\n\n # Save images\n else:\n cv2.imwrite(str(save_path.with_suffix(\".jpg\")), im) # save to JPG for best support\n\n def show(self, p: str = \"\"):\n \"\"\"Display an image in a window.\"\"\"\n im = self.plotted_img\n if platform.system() == \"Linux\" and p not in self.windows:\n self.windows.append(p)\n cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)\n cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)\n cv2.imshow(p, im)\n if cv2.waitKey(300 if self.dataset.mode == \"image\" else 1) & 0xFF == ord(\"q\"): # 300ms if image; else 1ms\n raise StopIteration\n\n def run_callbacks(self, event: str):\n \"\"\"Run all registered callbacks for a specific event.\"\"\"\n for callback in self.callbacks.get(event, []):\n callback(self)\n\n def add_callback(self, event: str, func: callable):\n \"\"\"Add a callback function for a specific event.\"\"\"\n self.callbacks[event].append(func)", "chunk_type": 
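A pure-Python sketch of the frame-index parsing used by write_results() for video sources; the log string format here is a hypothetical stand-in for the dataset's progress string.

import re

log_line = "video 1/1 (frame 42/300) demo.mp4: "    # hypothetical dataset progress string
match = re.search(r"frame (\d+)/", log_line)
frame = int(match[1]) if match else None            # None when the frame index cannot be determined
assert frame == 42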
"class", "name": "BasePredictor", "file_path": "ultralytics\\ultralytics\\engine\\predictor.py", "start_line": 67, "end_line": 511, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "A base class for creating predictors.\n\nThis class provides the foundation for prediction functionality, handling model setup, inference,\nand result processing across various input sources.\n\nAttributes:\n args (SimpleNamespace): Configuration for the predictor.\n save_dir (Path): Directory to save results.\n done_warmup (bool): Whether the predictor has finished setup.\n model (torch.nn.Module): Model used for prediction.\n data (dict): Data configuration.\n device (torch.device): Device used for prediction.\n dataset (Dataset): Dataset used for prediction.\n vid_writer (Dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.\n plotted_img (np.ndarray): Last plotted image.\n source_type (SimpleNamespace): Type of input source.\n seen (int): Number of images processed.\n windows (List[str]): List of window names for visualization.\n batch (tuple): Current batch data.\n results (List[Any]): Current batch results.\n transforms (callable): Image transforms for classification.\n callbacks (Dict[str, List[callable]]): Callback functions for different events.\n txt_path (Path): Path to save text results.\n _lock (threading.Lock): Lock for thread-safe inference.\n\nMethods:\n preprocess: Prepare input image before inference.\n inference: Run inference on a given image.\n postprocess: Process raw predictions into structured results.\n predict_cli: Run prediction for command line interface.\n setup_source: Set up input source and inference mode.\n stream_inference: Stream inference on input source.\n setup_model: Initialize and configure the model.\n write_results: Write inference results to files.\n save_predicted_images: Save prediction visualizations.\n show: Display results in a window.\n run_callbacks: Execute registered callbacks for an event.\n add_callback: Register a new callback function.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "platform", "re", "threading", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "ultralytics.cfg.get_cfg", "ultralytics.cfg.get_save_dir", "ultralytics.data.load_inference_source", "ultralytics.data.augment.LetterBox", "ultralytics.nn.autobackend.AutoBackend", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.WINDOWS", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.ops", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_imshow", "ultralytics.utils.files.increment_path", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode" ], "chunk_id": "class_BasePredictor_8e60b35e" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_f5cd86c9" }, { "content": "from functools import lru_cache", "chunk_type": "import", "name": "lru_cache", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 9, 
"end_line": 9, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_lru_cache_d6748b64" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_5b58cb42" }, { "content": "from typing import Any, Dict, List, Optional, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple, Union_492ef0e2" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_c8a11aac" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_fd3087b9" }, { "content": "from ultralytics.data.augment import LetterBox", "chunk_type": "import", "name": "LetterBox", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LetterBox_2ae4ac54" }, { "content": "from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, ops", "chunk_type": "import", "name": "LOGGER, DataExportMixin, SimpleClass, ops", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, DataExportMixin, SimpleClass, ops_43c52ecb" }, { "content": "from ultralytics.utils.plotting import Annotator, colors, save_one_box", "chunk_type": "import", "name": "Annotator, colors, save_one_box", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Annotator, colors, save_one_box_47e29d4a" }, { "content": "class BaseTensor(SimpleClass):\n \"\"\"\n Base tensor class with additional methods for easy manipulation and device handling.\n\n This class provides a foundation for tensor-like objects with device 
management capabilities,\n supporting both PyTorch tensors and NumPy arrays. It includes methods for moving data between\n devices and converting between tensor types.\n\n Attributes:\n data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.\n orig_shape (Tuple[int, int]): Original shape of the image, typically in the format (height, width).\n\n Methods:\n cpu: Return a copy of the tensor stored in CPU memory.\n numpy: Return a copy of the tensor as a numpy array.\n cuda: Move the tensor to GPU memory, returning a new instance if necessary.\n to: Return a copy of the tensor with the specified device and dtype.\n\n Examples:\n >>> import torch\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> orig_shape = (720, 1280)\n >>> base_tensor = BaseTensor(data, orig_shape)\n >>> cpu_tensor = base_tensor.cpu()\n >>> numpy_array = base_tensor.numpy()\n >>> gpu_tensor = base_tensor.cuda()\n \"\"\"\n\n def __init__(self, data: Union[torch.Tensor, np.ndarray], orig_shape: Tuple[int, int]) -> None:\n \"\"\"\n Initialize BaseTensor with prediction data and the original shape of the image.\n\n Args:\n data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.\n orig_shape (Tuple[int, int]): Original shape of the image in (height, width) format.\n\n Examples:\n >>> import torch\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> orig_shape = (720, 1280)\n >>> base_tensor = BaseTensor(data, orig_shape)\n \"\"\"\n assert isinstance(data, (torch.Tensor, np.ndarray)), \"data must be torch.Tensor or np.ndarray\"\n self.data = data\n self.orig_shape = orig_shape\n\n @property\n def shape(self) -> Tuple[int, ...]:\n \"\"\"\n Return the shape of the underlying data tensor.\n\n Returns:\n (Tuple[int, ...]): The shape of the data tensor.\n\n Examples:\n >>> data = torch.rand(100, 4)\n >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280))\n >>> print(base_tensor.shape)\n (100, 4)\n \"\"\"\n return self.data.shape\n\n def cpu(self):\n \"\"\"\n Return a copy of the tensor stored in CPU memory.\n\n Returns:\n (BaseTensor): A new BaseTensor object with the data tensor moved to CPU memory.\n\n Examples:\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]).cuda()\n >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280))\n >>> cpu_tensor = base_tensor.cpu()\n >>> isinstance(cpu_tensor, BaseTensor)\n True\n >>> cpu_tensor.data.device\n device(type='cpu')\n \"\"\"\n return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)\n\n def numpy(self):\n \"\"\"\n Return a copy of the tensor as a numpy array.\n\n Returns:\n (np.ndarray): A numpy array containing the same data as the original tensor.\n\n Examples:\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> orig_shape = (720, 1280)\n >>> base_tensor = BaseTensor(data, orig_shape)\n >>> numpy_array = base_tensor.numpy()\n >>> print(type(numpy_array))\n \n \"\"\"\n return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)\n\n def cuda(self):\n \"\"\"\n Move the tensor to GPU memory.\n\n Returns:\n (BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a\n numpy array, otherwise returns self.\n\n Examples:\n >>> import torch\n >>> from ultralytics.engine.results import BaseTensor\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280))\n >>> gpu_tensor = base_tensor.cuda()\n >>> 
print(gpu_tensor.data.device)\n cuda:0\n \"\"\"\n return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)\n\n def to(self, *args, **kwargs):\n \"\"\"\n Return a copy of the tensor with the specified device and dtype.\n\n Args:\n *args (Any): Variable length argument list to be passed to torch.Tensor.to().\n **kwargs (Any): Arbitrary keyword arguments to be passed to torch.Tensor.to().\n\n Returns:\n (BaseTensor): A new BaseTensor instance with the data moved to the specified device and/or dtype.\n\n Examples:\n >>> base_tensor = BaseTensor(torch.randn(3, 4), orig_shape=(480, 640))\n >>> cuda_tensor = base_tensor.to(\"cuda\")\n >>> float16_tensor = base_tensor.to(dtype=torch.float16)\n \"\"\"\n return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)\n\n def __len__(self) -> int:\n \"\"\"\n Return the length of the underlying data tensor.\n\n Returns:\n (int): The number of elements in the first dimension of the data tensor.\n\n Examples:\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280))\n >>> len(base_tensor)\n 2\n \"\"\"\n return len(self.data)\n\n def __getitem__(self, idx):\n \"\"\"\n Return a new BaseTensor instance containing the specified indexed elements of the data tensor.\n\n Args:\n idx (int | List[int] | torch.Tensor): Index or indices to select from the data tensor.\n\n Returns:\n (BaseTensor): A new BaseTensor instance containing the indexed data.\n\n Examples:\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280))\n >>> result = base_tensor[0] # Select the first row\n >>> print(result.data)\n tensor([1, 2, 3])\n \"\"\"\n return self.__class__(self.data[idx], self.orig_shape)", "chunk_type": "class", "name": "BaseTensor", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 21, "end_line": 187, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "Base tensor class with additional methods for easy manipulation and device handling.\n\nThis class provides a foundation for tensor-like objects with device management capabilities,\nsupporting both PyTorch tensors and NumPy arrays. 
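The BaseTensor wrapper above is essentially a thin device/dtype shim around a tensor or array. A minimal sketch of how it behaves with synthetic data (the shapes, values, and (720, 1280) image size are illustrative, and an installed `ultralytics` package is assumed):

```python
import torch

from ultralytics.engine.results import BaseTensor

# Synthetic prediction data: two rows of box-like values (illustrative only).
data = torch.tensor([[10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0]])
bt = BaseTensor(data, orig_shape=(720, 1280))

print(bt.shape)    # torch.Size([2, 4])
print(len(bt))     # 2 (length of the first dimension)
print(bt[0].data)  # tensor([10., 20., 30., 40.]) wrapped in a new BaseTensor

# Device/dtype handling; .cuda() is only meaningful when a GPU is available.
np_bt = bt.numpy()                    # underlying data becomes a numpy.ndarray
half_bt = bt.to(dtype=torch.float16)  # new BaseTensor with float16 data
print(type(np_bt.data), half_bt.data.dtype)
```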
It includes methods for moving data between\ndevices and converting between tensor types.\n\nAttributes:\n data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.\n orig_shape (Tuple[int, int]): Original shape of the image, typically in the format (height, width).\n\nMethods:\n cpu: Return a copy of the tensor stored in CPU memory.\n numpy: Return a copy of the tensor as a numpy array.\n cuda: Move the tensor to GPU memory, returning a new instance if necessary.\n to: Return a copy of the tensor with the specified device and dtype.\n\nExamples:\n >>> import torch\n >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])\n >>> orig_shape = (720, 1280)\n >>> base_tensor = BaseTensor(data, orig_shape)\n >>> cpu_tensor = base_tensor.cpu()\n >>> numpy_array = base_tensor.numpy()\n >>> gpu_tensor = base_tensor.cuda()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "SimpleClass" ], "chunk_id": "class_BaseTensor_25347712" }, { "content": "class Results(SimpleClass, DataExportMixin):\n \"\"\"\n A class for storing and manipulating inference results.\n\n This class provides comprehensive functionality for handling inference results from various\n Ultralytics models, including detection, segmentation, classification, and pose estimation.\n It supports visualization, data export, and various coordinate transformations.\n\n Attributes:\n orig_img (np.ndarray): The original image as a numpy array.\n orig_shape (Tuple[int, int]): Original image shape in (height, width) format.\n boxes (Boxes | None): Detected bounding boxes.\n masks (Masks | None): Segmentation masks.\n probs (Probs | None): Classification probabilities.\n keypoints (Keypoints | None): Detected keypoints.\n obb (OBB | None): Oriented bounding boxes.\n speed (dict): Dictionary containing inference speed information.\n names (dict): Dictionary mapping class indices to class names.\n path (str): Path to the input image file.\n save_dir (str | None): Directory to save results.\n\n Methods:\n update: Update the Results object with new detection data.\n cpu: Return a copy of the Results object with all tensors moved to CPU memory.\n numpy: Convert all tensors in the Results object to numpy arrays.\n cuda: Move all tensors in the Results object to GPU memory.\n to: Move all tensors to the specified device and dtype.\n new: Create a new Results object with the same image, path, names, and speed attributes.\n plot: Plot detection results on an input RGB image.\n show: Display the image with annotated inference results.\n save: Save annotated inference results image to file.\n verbose: Return a log string for each task in the results.\n save_txt: Save detection results to a text file.\n save_crop: Save cropped detection images to specified directory.\n summary: Convert inference results to a summarized dictionary.\n to_df: Convert detection results to a Pandas Dataframe.\n to_json: Convert detection results to JSON format.\n to_csv: Convert detection results to a CSV format.\n to_xml: Convert detection results to XML 
format.\n to_html: Convert detection results to HTML format.\n to_sql: Convert detection results to an SQL-compatible format.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> result = results[0] # Get the first result\n >>> boxes = result.boxes # Get the boxes for the first result\n >>> masks = result.masks # Get the masks for the first result\n >>> for result in results:\n >>> result.plot() # Plot detection results\n \"\"\"\n\n def __init__(\n self,\n orig_img: np.ndarray,\n path: str,\n names: Dict[int, str],\n boxes: Optional[torch.Tensor] = None,\n masks: Optional[torch.Tensor] = None,\n probs: Optional[torch.Tensor] = None,\n keypoints: Optional[torch.Tensor] = None,\n obb: Optional[torch.Tensor] = None,\n speed: Optional[Dict[str, float]] = None,\n ) -> None:\n \"\"\"\n Initialize the Results class for storing and manipulating inference results.\n\n Args:\n orig_img (np.ndarray): The original image as a numpy array.\n path (str): The path to the image file.\n names (dict): A dictionary of class names.\n boxes (torch.Tensor | None): A 2D tensor of bounding box coordinates for each detection.\n masks (torch.Tensor | None): A 3D tensor of detection masks, where each mask is a binary image.\n probs (torch.Tensor | None): A 1D tensor of probabilities of each class for classification task.\n keypoints (torch.Tensor | None): A 2D tensor of keypoint coordinates for each detection.\n obb (torch.Tensor | None): A 2D tensor of oriented bounding box coordinates for each detection.\n speed (Dict | None): A dictionary containing preprocess, inference, and postprocess speeds (ms/image).\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> result = results[0] # Get the first result\n >>> boxes = result.boxes # Get the boxes for the first result\n >>> masks = result.masks # Get the masks for the first result\n\n Notes:\n For the default pose model, keypoint indices for human body pose estimation are:\n 0: Nose, 1: Left Eye, 2: Right Eye, 3: Left Ear, 4: Right Ear\n 5: Left Shoulder, 6: Right Shoulder, 7: Left Elbow, 8: Right Elbow\n 9: Left Wrist, 10: Right Wrist, 11: Left Hip, 12: Right Hip\n 13: Left Knee, 14: Right Knee, 15: Left Ankle, 16: Right Ankle\n \"\"\"\n self.orig_img = orig_img\n self.orig_shape = orig_img.shape[:2]\n self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes\n self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks\n self.probs = Probs(probs) if probs is not None else None\n self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None\n self.obb = OBB(obb, self.orig_shape) if obb is not None else None\n self.speed = speed if speed is not None else {\"preprocess\": None, \"inference\": None, \"postprocess\": None}\n self.names = names\n self.path = path\n self.save_dir = None\n self._keys = \"boxes\", \"masks\", \"probs\", \"keypoints\", \"obb\"\n\n def __getitem__(self, idx):\n \"\"\"\n Return a Results object for a specific index of inference results.\n\n Args:\n idx (int | slice): Index or slice to retrieve from the Results object.\n\n Returns:\n (Results): A new Results object containing the specified subset of inference results.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\") # Perform inference\n >>> single_result = results[0] # Get the first result\n >>> subset_results = results[1:4] # Get a slice of results\n \"\"\"\n return self._apply(\"__getitem__\", idx)\n\n def __len__(self) -> int:\n \"\"\"\n Return the number 
of detections in the Results object.\n\n Returns:\n (int): The number of detections, determined by the length of the first non-empty\n attribute in (masks, probs, keypoints, or obb).\n\n Examples:\n >>> results = Results(orig_img, path, names, boxes=torch.rand(5, 4))\n >>> len(results)\n 5\n \"\"\"\n for k in self._keys:\n v = getattr(self, k)\n if v is not None:\n return len(v)\n\n def update(\n self,\n boxes: Optional[torch.Tensor] = None,\n masks: Optional[torch.Tensor] = None,\n probs: Optional[torch.Tensor] = None,\n obb: Optional[torch.Tensor] = None,\n keypoints: Optional[torch.Tensor] = None,\n ):\n \"\"\"\n Update the Results object with new detection data.\n\n This method allows updating the boxes, masks, probabilities, and oriented bounding boxes (OBB) of the\n Results object. It ensures that boxes are clipped to the original image shape.\n\n Args:\n boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and\n confidence scores. The format is (x1, y1, x2, y2, conf, class).\n masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.\n probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.\n obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.\n keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> new_boxes = torch.tensor([[100, 100, 200, 200, 0.9, 0]])\n >>> results[0].update(boxes=new_boxes)\n \"\"\"\n if boxes is not None:\n self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)\n if masks is not None:\n self.masks = Masks(masks, self.orig_shape)\n if probs is not None:\n self.probs = probs\n if obb is not None:\n self.obb = OBB(obb, self.orig_shape)\n if keypoints is not None:\n self.keypoints = Keypoints(keypoints, self.orig_shape)\n\n def _apply(self, fn: str, *args, **kwargs):\n \"\"\"\n Apply a function to all non-empty attributes and return a new Results object with modified attributes.\n\n This method is internally called by methods like .to(), .cuda(), .cpu(), etc.\n\n Args:\n fn (str): The name of the function to apply.\n *args (Any): Variable length argument list to pass to the function.\n **kwargs (Any): Arbitrary keyword arguments to pass to the function.\n\n Returns:\n (Results): A new Results object with attributes modified by the applied function.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> for result in results:\n ... result_cuda = result.cuda()\n ... result_cpu = result.cpu()\n \"\"\"\n r = self.new()\n for k in self._keys:\n v = getattr(self, k)\n if v is not None:\n setattr(r, k, getattr(v, fn)(*args, **kwargs))\n return r\n\n def cpu(self):\n \"\"\"\n Return a copy of the Results object with all its tensors moved to CPU memory.\n\n This method creates a new Results object with all tensor attributes (boxes, masks, probs, keypoints, obb)\n transferred to CPU memory. 
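A short sketch of constructing a Results object directly and exercising `len()` and `update()` without running a model; the blank frame, class map, and box values below are fabricated for illustration:

```python
import numpy as np
import torch

from ultralytics.engine.results import Results

# Hypothetical inputs: a blank 640x480 BGR frame and one detection row
# in (x1, y1, x2, y2, conf, class) format.
img = np.zeros((480, 640, 3), dtype=np.uint8)
boxes = torch.tensor([[50.0, 60.0, 200.0, 220.0, 0.91, 0.0]])

res = Results(orig_img=img, path="image.jpg", names={0: "person"}, boxes=boxes)
print(len(res))  # 1 detection

# Replace the stored boxes; update() clips coordinates to the original image shape.
res.update(boxes=torch.tensor([[10.0, 10.0, 900.0, 900.0, 0.50, 0.0]]))
print(res.boxes.xyxy)  # x2/y2 clamped to 640/480
```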
It's useful for moving data from GPU to CPU for further processing or saving.\n\n Returns:\n (Results): A new Results object with all tensor attributes on CPU memory.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\") # Perform inference\n >>> cpu_result = results[0].cpu() # Move the first result to CPU\n >>> print(cpu_result.boxes.device) # Output: cpu\n \"\"\"\n return self._apply(\"cpu\")\n\n def numpy(self):\n \"\"\"\n Convert all tensors in the Results object to numpy arrays.\n\n Returns:\n (Results): A new Results object with all tensors converted to numpy arrays.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> numpy_result = results[0].numpy()\n >>> type(numpy_result.boxes.data)\n \n\n Notes:\n This method creates a new Results object, leaving the original unchanged. It's useful for\n interoperability with numpy-based libraries or when CPU-based operations are required.\n \"\"\"\n return self._apply(\"numpy\")\n\n def cuda(self):\n \"\"\"\n Move all tensors in the Results object to GPU memory.\n\n Returns:\n (Results): A new Results object with all tensors moved to CUDA device.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> cuda_results = results[0].cuda() # Move first result to GPU\n >>> for result in results:\n ... result_cuda = result.cuda() # Move each result to GPU\n \"\"\"\n return self._apply(\"cuda\")\n\n def to(self, *args, **kwargs):\n \"\"\"\n Move all tensors in the Results object to the specified device and dtype.\n\n Args:\n *args (Any): Variable length argument list to be passed to torch.Tensor.to().\n **kwargs (Any): Arbitrary keyword arguments to be passed to torch.Tensor.to().\n\n Returns:\n (Results): A new Results object with all tensors moved to the specified device and dtype.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> result_cuda = results[0].to(\"cuda\") # Move first result to GPU\n >>> result_cpu = results[0].to(\"cpu\") # Move first result to CPU\n >>> result_half = results[0].to(dtype=torch.float16) # Convert first result to half precision\n \"\"\"\n return self._apply(\"to\", *args, **kwargs)\n\n def new(self):\n \"\"\"\n Create a new Results object with the same image, path, names, and speed attributes.\n\n Returns:\n (Results): A new Results object with copied attributes from the original instance.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> new_result = results[0].new()\n \"\"\"\n return Results(orig_img=self.orig_img, path=self.path, names=self.names, speed=self.speed)\n\n def plot(\n self,\n conf: bool = True,\n line_width: Optional[float] = None,\n font_size: Optional[float] = None,\n font: str = \"Arial.ttf\",\n pil: bool = False,\n img: Optional[np.ndarray] = None,\n im_gpu: Optional[torch.Tensor] = None,\n kpt_radius: int = 5,\n kpt_line: bool = True,\n labels: bool = True,\n boxes: bool = True,\n masks: bool = True,\n probs: bool = True,\n show: bool = False,\n save: bool = False,\n filename: Optional[str] = None,\n color_mode: str = \"class\",\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ) -> np.ndarray:\n \"\"\"\n Plot detection results on an input RGB image.\n\n Args:\n conf (bool): Whether to plot detection confidence scores.\n line_width (float | None): Line width of bounding boxes. If None, scaled to image size.\n font_size (float | None): Font size for text. If None, scaled to image size.\n font (str): Font to use for text.\n pil (bool): Whether to return the image as a PIL Image.\n img (np.ndarray | None): Image to plot on. 
If None, uses original image.\n im_gpu (torch.Tensor | None): Normalized image on GPU for faster mask plotting.\n kpt_radius (int): Radius of drawn keypoints.\n kpt_line (bool): Whether to draw lines connecting keypoints.\n labels (bool): Whether to plot labels of bounding boxes.\n boxes (bool): Whether to plot bounding boxes.\n masks (bool): Whether to plot masks.\n probs (bool): Whether to plot classification probabilities.\n show (bool): Whether to display the annotated image.\n save (bool): Whether to save the annotated image.\n filename (str | None): Filename to save image if save is True.\n color_mode (str): Specify the color mode, e.g., 'instance' or 'class'.\n txt_color (tuple[int, int, int]): Specify the RGB text color for classification task.\n\n Returns:\n (np.ndarray): Annotated image as a numpy array.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> for result in results:\n >>> im = result.plot()\n >>> im.show()\n \"\"\"\n assert color_mode in {\"instance\", \"class\"}, f\"Expected color_mode='instance' or 'class', not {color_mode}.\"\n if img is None and isinstance(self.orig_img, torch.Tensor):\n img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()\n\n names = self.names\n is_obb = self.obb is not None\n pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes\n pred_masks, show_masks = self.masks, masks\n pred_probs, show_probs = self.probs, probs\n annotator = Annotator(\n deepcopy(self.orig_img if img is None else img),\n line_width,\n font_size,\n font,\n pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True\n example=names,\n )\n\n # Plot Segment results\n if pred_masks and show_masks:\n if im_gpu is None:\n img = LetterBox(pred_masks.shape[1:])(image=annotator.result())\n im_gpu = (\n torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)\n .permute(2, 0, 1)\n .flip(0)\n .contiguous()\n / 255\n )\n idx = (\n pred_boxes.id\n if pred_boxes.is_track and color_mode == \"instance\"\n else pred_boxes.cls\n if pred_boxes and color_mode == \"class\"\n else reversed(range(len(pred_masks)))\n )\n annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)\n\n # Plot Detect results\n if pred_boxes is not None and show_boxes:\n for i, d in enumerate(reversed(pred_boxes)):\n c, d_conf, id = int(d.cls), float(d.conf) if conf else None, int(d.id.item()) if d.is_track else None\n name = (\"\" if id is None else f\"id:{id} \") + names[c]\n label = (f\"{name} {d_conf:.2f}\" if conf else name) if labels else None\n box = d.xyxyxyxy.squeeze() if is_obb else d.xyxy.squeeze()\n annotator.box_label(\n box,\n label,\n color=colors(\n c\n if color_mode == \"class\"\n else id\n if id is not None\n else i\n if color_mode == \"instance\"\n else None,\n True,\n ),\n )\n\n # Plot Classify results\n if pred_probs is not None and show_probs:\n text = \"\\n\".join(f\"{names[j] if names else j} {pred_probs.data[j]:.2f}\" for j in pred_probs.top5)\n x = round(self.orig_shape[0] * 0.03)\n annotator.text([x, x], text, txt_color=txt_color, box_color=(64, 64, 64, 128)) # RGBA box\n\n # Plot Pose results\n if self.keypoints is not None:\n for i, k in enumerate(reversed(self.keypoints.data)):\n annotator.kpts(\n k,\n self.orig_shape,\n radius=kpt_radius,\n kpt_line=kpt_line,\n kpt_color=colors(i, True) if color_mode == \"instance\" else None,\n )\n\n # Show results\n if show:\n annotator.show(self.path)\n\n # Save results\n if save:\n annotator.save(filename or 
f\"results_{Path(self.path).name}\")\n\n return annotator.im if pil else annotator.result()\n\n def show(self, *args, **kwargs):\n \"\"\"\n Display the image with annotated inference results.\n\n This method plots the detection results on the original image and displays it. It's a convenient way to\n visualize the model's predictions directly.\n\n Args:\n *args (Any): Variable length argument list to be passed to the `plot()` method.\n **kwargs (Any): Arbitrary keyword arguments to be passed to the `plot()` method.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> results[0].show() # Display the first result\n >>> for result in results:\n >>> result.show() # Display all results\n \"\"\"\n self.plot(show=True, *args, **kwargs)\n\n def save(self, filename: Optional[str] = None, *args, **kwargs) -> str:\n \"\"\"\n Save annotated inference results image to file.\n\n This method plots the detection results on the original image and saves the annotated image to a file. It\n utilizes the `plot` method to generate the annotated image and then saves it to the specified filename.\n\n Args:\n filename (str | Path | None): The filename to save the annotated image. If None, a default filename\n is generated based on the original image path.\n *args (Any): Variable length argument list to be passed to the `plot` method.\n **kwargs (Any): Arbitrary keyword arguments to be passed to the `plot` method.\n\n Returns:\n (str): The filename where the image was saved.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> for result in results:\n >>> result.save(\"annotated_image.jpg\")\n >>> # Or with custom plot arguments\n >>> for result in results:\n >>> result.save(\"annotated_image.jpg\", conf=False, line_width=2)\n \"\"\"\n if not filename:\n filename = f\"results_{Path(self.path).name}\"\n self.plot(save=True, filename=filename, *args, **kwargs)\n return filename\n\n def verbose(self) -> str:\n \"\"\"\n Return a log string for each task in the results, detailing detection and classification outcomes.\n\n This method generates a human-readable string summarizing the detection and classification results. It includes\n the number of detections for each class and the top probabilities for classification tasks.\n\n Returns:\n (str): A formatted string containing a summary of the results. For detection tasks, it includes the\n number of detections per class. 
For classification tasks, it includes the top 5 class probabilities.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> for result in results:\n >>> print(result.verbose())\n 2 persons, 1 car, 3 traffic lights,\n dog 0.92, cat 0.78, horse 0.64,\n\n Notes:\n - If there are no detections, the method returns \"(no detections), \" for detection tasks.\n - For classification tasks, it returns the top 5 class probabilities and their corresponding class names.\n - The returned string is comma-separated and ends with a comma and a space.\n \"\"\"\n probs = self.probs\n if len(self) == 0:\n return \"\" if probs is not None else \"(no detections), \"\n if probs is not None:\n return f\"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, \"\n if boxes := self.boxes:\n counts = boxes.cls.int().bincount()\n return \"\".join(f\"{n} {self.names[i]}{'s' * (n > 1)}, \" for i, n in enumerate(counts) if n > 0)\n\n def save_txt(self, txt_file: Union[str, Path], save_conf: bool = False) -> str:\n \"\"\"\n Save detection results to a text file.\n\n Args:\n txt_file (str | Path): Path to the output text file.\n save_conf (bool): Whether to include confidence scores in the output.\n\n Returns:\n (str): Path to the saved text file.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model(\"path/to/image.jpg\")\n >>> for result in results:\n >>> result.save_txt(\"output.txt\")\n\n Notes:\n - The file will contain one line per detection or classification with the following structure:\n - For detections: `class confidence x_center y_center width height`\n - For classifications: `confidence class_name`\n - For masks and keypoints, the specific formats will vary accordingly.\n - The function will create the output directory if it does not exist.\n - If save_conf is False, the confidence scores will be excluded from the output.\n - Existing contents of the file will not be overwritten; new results will be appended.\n \"\"\"\n is_obb = self.obb is not None\n boxes = self.obb if is_obb else self.boxes\n masks = self.masks\n probs = self.probs\n kpts = self.keypoints\n texts = []\n if probs is not None:\n # Classify\n [texts.append(f\"{probs.data[j]:.2f} {self.names[j]}\") for j in probs.top5]\n elif boxes:\n # Detect/segment/pose\n for j, d in enumerate(boxes):\n c, conf, id = int(d.cls), float(d.conf), int(d.id.item()) if d.is_track else None\n line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1)))\n if masks:\n seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)\n line = (c, *seg)\n if kpts is not None:\n kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn\n line += (*kpt.reshape(-1).tolist(),)\n line += (conf,) * save_conf + (() if id is None else (id,))\n texts.append((\"%g \" * len(line)).rstrip() % line)\n\n if texts:\n Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory\n with open(txt_file, \"a\", encoding=\"utf-8\") as f:\n f.writelines(text + \"\\n\" for text in texts)\n\n return str(txt_file)\n\n def save_crop(self, save_dir: Union[str, Path], file_name: Union[str, Path] = Path(\"im.jpg\")):\n \"\"\"\n Save cropped detection images to specified directory.\n\n This method saves cropped images of detected objects to a specified directory. 
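A brief sketch of `verbose()` and `save_txt()` with two fabricated detections; the class names and output path are illustrative:

```python
import numpy as np
import torch

from ultralytics.engine.results import Results

img = np.zeros((480, 640, 3), dtype=np.uint8)
boxes = torch.tensor(
    [[50.0, 60.0, 200.0, 220.0, 0.91, 0.0], [300.0, 100.0, 400.0, 200.0, 0.75, 1.0]]
)
res = Results(orig_img=img, path="frame.jpg", names={0: "person", 1: "car"}, boxes=boxes)

print(res.verbose())  # "1 person, 1 car, "

# Appends one "class x_center y_center width height conf" line per detection,
# with coordinates normalized to the original image size.
res.save_txt("labels.txt", save_conf=True)
```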
Each crop is saved in a\n subdirectory named after the object's class, with the filename based on the input file_name.\n\n Args:\n save_dir (str | Path): Directory path where cropped images will be saved.\n file_name (str | Path): Base filename for the saved cropped images.\n\n Notes:\n - This method does not support Classify or Oriented Bounding Box (OBB) tasks.\n - Crops are saved as 'save_dir/class_name/file_name.jpg'.\n - The method will create necessary subdirectories if they don't exist.\n - Original image is copied before cropping to avoid modifying the original.\n\n Examples:\n >>> results = model(\"path/to/image.jpg\")\n >>> for result in results:\n >>> result.save_crop(save_dir=\"path/to/crops\", file_name=\"detection\")\n \"\"\"\n if self.probs is not None:\n LOGGER.warning(\"Classify task do not support `save_crop`.\")\n return\n if self.obb is not None:\n LOGGER.warning(\"OBB task do not support `save_crop`.\")\n return\n for d in self.boxes:\n save_one_box(\n d.xyxy,\n self.orig_img.copy(),\n file=Path(save_dir) / self.names[int(d.cls)] / Path(file_name).with_suffix(\".jpg\"),\n BGR=True,\n )\n\n def summary(self, normalize: bool = False, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Convert inference results to a summarized dictionary with optional normalization for box coordinates.\n\n This method creates a list of detection dictionaries, each containing information about a single\n detection or classification result. For classification tasks, it returns the top class and its\n confidence. For detection tasks, it includes class information, bounding box coordinates, and\n optionally mask segments and keypoints.\n\n Args:\n normalize (bool): Whether to normalize bounding box coordinates by image dimensions.\n decimals (int): Number of decimal places to round the output values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each containing summarized information for a single detection\n or classification result. 
The structure of each dictionary varies based on the task type\n (classification or detection) and available information (boxes, masks, keypoints).\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> for result in results:\n >>> summary = result.summary()\n >>> print(summary)\n \"\"\"\n # Create list of detection dictionaries\n results = []\n if self.probs is not None:\n class_id = self.probs.top1\n results.append(\n {\n \"name\": self.names[class_id],\n \"class\": class_id,\n \"confidence\": round(self.probs.top1conf.item(), decimals),\n }\n )\n return results\n\n is_obb = self.obb is not None\n data = self.obb if is_obb else self.boxes\n h, w = self.orig_shape if normalize else (1, 1)\n for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id\n class_id, conf = int(row.cls), round(row.conf.item(), decimals)\n box = (row.xyxyxyxy if is_obb else row.xyxy).squeeze().reshape(-1, 2).tolist()\n xy = {}\n for j, b in enumerate(box):\n xy[f\"x{j + 1}\"] = round(b[0] / w, decimals)\n xy[f\"y{j + 1}\"] = round(b[1] / h, decimals)\n result = {\"name\": self.names[class_id], \"class\": class_id, \"confidence\": conf, \"box\": xy}\n if data.is_track:\n result[\"track_id\"] = int(row.id.item()) # track ID\n if self.masks:\n result[\"segments\"] = {\n \"x\": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(),\n \"y\": (self.masks.xy[i][:, 1] / h).round(decimals).tolist(),\n }\n if self.keypoints is not None:\n x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor\n result[\"keypoints\"] = {\n \"x\": (x / w).numpy().round(decimals).tolist(), # decimals named argument required\n \"y\": (y / h).numpy().round(decimals).tolist(),\n \"visible\": visible.numpy().round(decimals).tolist(),\n }\n results.append(result)\n\n return results", "chunk_type": "class", "name": "Results", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 190, "end_line": 853, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "A class for storing and manipulating inference results.\n\nThis class provides comprehensive functionality for handling inference results from various\nUltralytics models, including detection, segmentation, classification, and pose estimation.\nIt supports visualization, data export, and various coordinate transformations.\n\nAttributes:\n orig_img (np.ndarray): The original image as a numpy array.\n orig_shape (Tuple[int, int]): Original image shape in (height, width) format.\n boxes (Boxes | None): Detected bounding boxes.\n masks (Masks | None): Segmentation masks.\n probs (Probs | None): Classification probabilities.\n keypoints (Keypoints | None): Detected keypoints.\n obb (OBB | None): Oriented bounding boxes.\n speed (dict): Dictionary containing inference speed information.\n names (dict): Dictionary mapping class indices to class names.\n path (str): Path to the input image file.\n save_dir (str | None): Directory to save results.\n\nMethods:\n update: Update the Results object with new detection data.\n cpu: Return a copy of the Results object with all tensors moved to CPU memory.\n numpy: Convert all tensors in the Results object to numpy arrays.\n cuda: Move all tensors in the Results object to GPU memory.\n to: Move all tensors to the specified device and dtype.\n new: Create a new Results object with the same image, path, names, and speed attributes.\n plot: Plot detection results on an input RGB image.\n show: Display the image with annotated inference results.\n save: Save annotated inference results image to 
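And a sketch of `summary()`, again with a fabricated single-box Results; `normalize=True` divides coordinates by the (height, width) stored in `orig_shape`:

```python
import numpy as np
import torch

from ultralytics.engine.results import Results

img = np.zeros((480, 640, 3), dtype=np.uint8)
boxes = torch.tensor([[50.0, 60.0, 200.0, 220.0, 0.91, 0.0]])
res = Results(orig_img=img, path="frame.jpg", names={0: "person"}, boxes=boxes)

for det in res.summary(normalize=True, decimals=3):
    print(det["name"], det["class"], det["confidence"], det["box"])
# DataExportMixin additionally exposes exporters such as to_json()/to_csv(),
# e.g. res.to_json(); some of these may rely on optional extras like pandas.
```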
file.\n verbose: Return a log string for each task in the results.\n save_txt: Save detection results to a text file.\n save_crop: Save cropped detection images to specified directory.\n summary: Convert inference results to a summarized dictionary.\n to_df: Convert detection results to a Pandas Dataframe.\n to_json: Convert detection results to JSON format.\n to_csv: Convert detection results to a CSV format.\n to_xml: Convert detection results to XML format.\n to_html: Convert detection results to HTML format.\n to_sql: Convert detection results to an SQL-compatible format.\n\nExamples:\n >>> results = model(\"path/to/image.jpg\")\n >>> result = results[0] # Get the first result\n >>> boxes = result.boxes # Get the boxes for the first result\n >>> masks = result.masks # Get the masks for the first result\n >>> for result in results:\n >>> result.plot() # Plot detection results", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "SimpleClass", "DataExportMixin" ], "chunk_id": "class_Results_5e340baf" }, { "content": "class Boxes(BaseTensor):\n \"\"\"\n A class for managing and manipulating detection boxes.\n\n This class provides comprehensive functionality for handling detection boxes, including their coordinates,\n confidence scores, class labels, and optional tracking IDs. 
It supports various box formats and offers\n methods for easy manipulation and conversion between different coordinate systems.\n\n Attributes:\n data (torch.Tensor | np.ndarray): The raw tensor containing detection boxes and associated data.\n orig_shape (Tuple[int, int]): The original image dimensions (height, width).\n is_track (bool): Indicates whether tracking IDs are included in the box data.\n xyxy (torch.Tensor | np.ndarray): Boxes in [x1, y1, x2, y2] format.\n conf (torch.Tensor | np.ndarray): Confidence scores for each box.\n cls (torch.Tensor | np.ndarray): Class labels for each box.\n id (torch.Tensor | None): Tracking IDs for each box (if available).\n xywh (torch.Tensor | np.ndarray): Boxes in [x, y, width, height] format.\n xyxyn (torch.Tensor | np.ndarray): Normalized [x1, y1, x2, y2] boxes relative to orig_shape.\n xywhn (torch.Tensor | np.ndarray): Normalized [x, y, width, height] boxes relative to orig_shape.\n\n Methods:\n cpu: Return a copy of the object with all tensors on CPU memory.\n numpy: Return a copy of the object with all tensors as numpy arrays.\n cuda: Return a copy of the object with all tensors on GPU memory.\n to: Return a copy of the object with tensors on specified device and dtype.\n\n Examples:\n >>> import torch\n >>> boxes_data = torch.tensor([[100, 50, 150, 100, 0.9, 0], [200, 150, 300, 250, 0.8, 1]])\n >>> orig_shape = (480, 640) # height, width\n >>> boxes = Boxes(boxes_data, orig_shape)\n >>> print(boxes.xyxy)\n >>> print(boxes.conf)\n >>> print(boxes.cls)\n >>> print(boxes.xywhn)\n \"\"\"\n\n def __init__(self, boxes: Union[torch.Tensor, np.ndarray], orig_shape: Tuple[int, int]) -> None:\n \"\"\"\n Initialize the Boxes class with detection box data and the original image shape.\n\n This class manages detection boxes, providing easy access and manipulation of box coordinates,\n confidence scores, class identifiers, and optional tracking IDs. It supports multiple formats\n for box coordinates, including both absolute and normalized forms.\n\n Args:\n boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape\n (num_boxes, 6) or (num_boxes, 7). Columns should contain\n [x1, y1, x2, y2, confidence, class, (optional) track_id].\n orig_shape (Tuple[int, int]): The original image shape as (height, width). 
Used for normalization.\n\n Attributes:\n data (torch.Tensor): The raw tensor containing detection boxes and their associated data.\n orig_shape (Tuple[int, int]): The original image size, used for normalization.\n is_track (bool): Indicates whether tracking IDs are included in the box data.\n\n Examples:\n >>> import torch\n >>> boxes = torch.tensor([[100, 50, 150, 100, 0.9, 0]])\n >>> orig_shape = (480, 640)\n >>> detection_boxes = Boxes(boxes, orig_shape)\n >>> print(detection_boxes.xyxy)\n tensor([[100., 50., 150., 100.]])\n \"\"\"\n if boxes.ndim == 1:\n boxes = boxes[None, :]\n n = boxes.shape[-1]\n assert n in {6, 7}, f\"expected 6 or 7 values but got {n}\" # xyxy, track_id, conf, cls\n super().__init__(boxes, orig_shape)\n self.is_track = n == 7\n self.orig_shape = orig_shape\n\n @property\n def xyxy(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return bounding boxes in [x1, y1, x2, y2] format.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box\n coordinates in [x1, y1, x2, y2] format, where n is the number of boxes.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> boxes = results[0].boxes\n >>> xyxy = boxes.xyxy\n >>> print(xyxy)\n \"\"\"\n return self.data[:, :4]\n\n @property\n def conf(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return the confidence scores for each detection box.\n\n Returns:\n (torch.Tensor | np.ndarray): A 1D tensor or array containing confidence scores for each detection,\n with shape (N,) where N is the number of detections.\n\n Examples:\n >>> boxes = Boxes(torch.tensor([[10, 20, 30, 40, 0.9, 0]]), orig_shape=(100, 100))\n >>> conf_scores = boxes.conf\n >>> print(conf_scores)\n tensor([0.9000])\n \"\"\"\n return self.data[:, -2]\n\n @property\n def cls(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return the class ID tensor representing category predictions for each bounding box.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or numpy array containing the class IDs for each detection box.\n The shape is (N,), where N is the number of boxes.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> boxes = results[0].boxes\n >>> class_ids = boxes.cls\n >>> print(class_ids) # tensor([0., 2., 1.])\n \"\"\"\n return self.data[:, -1]\n\n @property\n def id(self) -> Optional[Union[torch.Tensor, np.ndarray]]:\n \"\"\"\n Return the tracking IDs for each detection box if available.\n\n Returns:\n (torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled,\n otherwise None. Shape is (N,) where N is the number of boxes.\n\n Examples:\n >>> results = model.track(\"path/to/video.mp4\")\n >>> for result in results:\n ... boxes = result.boxes\n ... if boxes.is_track:\n ... track_ids = boxes.id\n ... print(f\"Tracking IDs: {track_ids}\")\n ... else:\n ... 
print(\"Tracking is not enabled for these boxes.\")\n\n Notes:\n - This property is only available when tracking is enabled (i.e., when `is_track` is True).\n - The tracking IDs are typically used to associate detections across multiple frames in video analysis.\n \"\"\"\n return self.data[:, -3] if self.is_track else None\n\n @property\n @lru_cache(maxsize=2)\n def xywh(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format.\n\n Returns:\n (torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center,\n y_center are the coordinates of the center point of the bounding box, width, height are the\n dimensions of the bounding box and the shape of the returned tensor is (N, 4), where N is the\n number of boxes.\n\n Examples:\n >>> boxes = Boxes(torch.tensor([[100, 50, 150, 100], [200, 150, 300, 250]]), orig_shape=(480, 640))\n >>> xywh = boxes.xywh\n >>> print(xywh)\n tensor([[100.0000, 50.0000, 50.0000, 50.0000],\n [200.0000, 150.0000, 100.0000, 100.0000]])\n \"\"\"\n return ops.xyxy2xywh(self.xyxy)\n\n @property\n @lru_cache(maxsize=2)\n def xyxyn(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return normalized bounding box coordinates relative to the original image size.\n\n This property calculates and returns the bounding box coordinates in [x1, y1, x2, y2] format,\n normalized to the range [0, 1] based on the original image dimensions.\n\n Returns:\n (torch.Tensor | np.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is\n the number of boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1].\n\n Examples:\n >>> boxes = Boxes(torch.tensor([[100, 50, 300, 400, 0.9, 0]]), orig_shape=(480, 640))\n >>> normalized = boxes.xyxyn\n >>> print(normalized)\n tensor([[0.1562, 0.1042, 0.4688, 0.8333]])\n \"\"\"\n xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)\n xyxy[..., [0, 2]] /= self.orig_shape[1]\n xyxy[..., [1, 3]] /= self.orig_shape[0]\n return xyxy\n\n @property\n @lru_cache(maxsize=2)\n def xywhn(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return normalized bounding boxes in [x, y, width, height] format.\n\n This property calculates and returns the normalized bounding box coordinates in the format\n [x_center, y_center, width, height], where all values are relative to the original image dimensions.\n\n Returns:\n (torch.Tensor | np.ndarray): Normalized bounding boxes with shape (N, 4), where N is the\n number of boxes. Each row contains [x_center, y_center, width, height] values normalized\n to [0, 1] based on the original image dimensions.\n\n Examples:\n >>> boxes = Boxes(torch.tensor([[100, 50, 150, 100, 0.9, 0]]), orig_shape=(480, 640))\n >>> normalized = boxes.xywhn\n >>> print(normalized)\n tensor([[0.1953, 0.1562, 0.0781, 0.1042]])\n \"\"\"\n xywh = ops.xyxy2xywh(self.xyxy)\n xywh[..., [0, 2]] /= self.orig_shape[1]\n xywh[..., [1, 3]] /= self.orig_shape[0]\n return xywh", "chunk_type": "class", "name": "Boxes", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 856, "end_line": 1072, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "A class for managing and manipulating detection boxes.\n\nThis class provides comprehensive functionality for handling detection boxes, including their coordinates,\nconfidence scores, class labels, and optional tracking IDs. 
It supports various box formats and offers\nmethods for easy manipulation and conversion between different coordinate systems.\n\nAttributes:\n data (torch.Tensor | np.ndarray): The raw tensor containing detection boxes and associated data.\n orig_shape (Tuple[int, int]): The original image dimensions (height, width).\n is_track (bool): Indicates whether tracking IDs are included in the box data.\n xyxy (torch.Tensor | np.ndarray): Boxes in [x1, y1, x2, y2] format.\n conf (torch.Tensor | np.ndarray): Confidence scores for each box.\n cls (torch.Tensor | np.ndarray): Class labels for each box.\n id (torch.Tensor | None): Tracking IDs for each box (if available).\n xywh (torch.Tensor | np.ndarray): Boxes in [x, y, width, height] format.\n xyxyn (torch.Tensor | np.ndarray): Normalized [x1, y1, x2, y2] boxes relative to orig_shape.\n xywhn (torch.Tensor | np.ndarray): Normalized [x, y, width, height] boxes relative to orig_shape.\n\nMethods:\n cpu: Return a copy of the object with all tensors on CPU memory.\n numpy: Return a copy of the object with all tensors as numpy arrays.\n cuda: Return a copy of the object with all tensors on GPU memory.\n to: Return a copy of the object with tensors on specified device and dtype.\n\nExamples:\n >>> import torch\n >>> boxes_data = torch.tensor([[100, 50, 150, 100, 0.9, 0], [200, 150, 300, 250, 0.8, 1]])\n >>> orig_shape = (480, 640) # height, width\n >>> boxes = Boxes(boxes_data, orig_shape)\n >>> print(boxes.xyxy)\n >>> print(boxes.conf)\n >>> print(boxes.cls)\n >>> print(boxes.xywhn)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "BaseTensor" ], "chunk_id": "class_Boxes_a69d7de2" }, { "content": "class Masks(BaseTensor):\n \"\"\"\n A class for storing and manipulating detection masks.\n\n This class extends BaseTensor and provides functionality for handling segmentation masks,\n including methods for converting between pixel and normalized coordinates.\n\n Attributes:\n data (torch.Tensor | np.ndarray): The raw tensor or array containing mask data.\n orig_shape (tuple): Original image shape in (height, width) format.\n xy (List[np.ndarray]): A list of segments in pixel coordinates.\n xyn (List[np.ndarray]): A list of normalized segments.\n\n Methods:\n cpu: Return a copy of the Masks object with the mask tensor on CPU memory.\n numpy: Return a copy of the Masks object with the mask tensor as a numpy array.\n cuda: Return a copy of the Masks object with the mask tensor on GPU memory.\n to: Return a copy of the Masks object with the mask tensor on specified device and dtype.\n\n Examples:\n >>> masks_data = torch.rand(1, 160, 160)\n >>> orig_shape = (720, 1280)\n >>> masks = Masks(masks_data, orig_shape)\n >>> pixel_coords = masks.xy\n >>> normalized_coords = masks.xyn\n \"\"\"\n\n def __init__(self, masks: Union[torch.Tensor, np.ndarray], orig_shape: Tuple[int, int]) -> None:\n \"\"\"\n Initialize the Masks class with detection mask data and the original image shape.\n\n Args:\n masks (torch.Tensor | np.ndarray): Detection masks with shape (num_masks, 
height, width).\n orig_shape (tuple): The original image shape as (height, width). Used for normalization.\n\n Examples:\n >>> import torch\n >>> from ultralytics.engine.results import Masks\n >>> masks = torch.rand(10, 160, 160) # 10 masks of 160x160 resolution\n >>> orig_shape = (720, 1280) # Original image shape\n >>> mask_obj = Masks(masks, orig_shape)\n \"\"\"\n if masks.ndim == 2:\n masks = masks[None, :]\n super().__init__(masks, orig_shape)\n\n @property\n @lru_cache(maxsize=1)\n def xyn(self) -> List[np.ndarray]:\n \"\"\"\n Return normalized xy-coordinates of the segmentation masks.\n\n This property calculates and caches the normalized xy-coordinates of the segmentation masks. The coordinates\n are normalized relative to the original image shape.\n\n Returns:\n (List[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates\n of a single segmentation mask. Each array has shape (N, 2), where N is the number of points in the\n mask contour.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> masks = results[0].masks\n >>> normalized_coords = masks.xyn\n >>> print(normalized_coords[0]) # Normalized coordinates of the first mask\n \"\"\"\n return [\n ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)\n for x in ops.masks2segments(self.data)\n ]\n\n @property\n @lru_cache(maxsize=1)\n def xy(self) -> List[np.ndarray]:\n \"\"\"\n Return the [x, y] pixel coordinates for each segment in the mask tensor.\n\n This property calculates and returns a list of pixel coordinates for each segmentation mask in the\n Masks object. The coordinates are scaled to match the original image dimensions.\n\n Returns:\n (List[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel\n coordinates for a single segmentation mask. 
Each array has shape (N, 2), where N is the\n number of points in the segment.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> masks = results[0].masks\n >>> xy_coords = masks.xy\n >>> print(len(xy_coords)) # Number of masks\n >>> print(xy_coords[0].shape) # Shape of first mask's coordinates\n \"\"\"\n return [\n ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)\n for x in ops.masks2segments(self.data)\n ]", "chunk_type": "class", "name": "Masks", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 1075, "end_line": 1170, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class for storing and manipulating detection masks.\n\nThis class extends BaseTensor and provides functionality for handling segmentation masks,\nincluding methods for converting between pixel and normalized coordinates.\n\nAttributes:\n data (torch.Tensor | np.ndarray): The raw tensor or array containing mask data.\n orig_shape (tuple): Original image shape in (height, width) format.\n xy (List[np.ndarray]): A list of segments in pixel coordinates.\n xyn (List[np.ndarray]): A list of normalized segments.\n\nMethods:\n cpu: Return a copy of the Masks object with the mask tensor on CPU memory.\n numpy: Return a copy of the Masks object with the mask tensor as a numpy array.\n cuda: Return a copy of the Masks object with the mask tensor on GPU memory.\n to: Return a copy of the Masks object with the mask tensor on specified device and dtype.\n\nExamples:\n >>> masks_data = torch.rand(1, 160, 160)\n >>> orig_shape = (720, 1280)\n >>> masks = Masks(masks_data, orig_shape)\n >>> pixel_coords = masks.xy\n >>> normalized_coords = masks.xyn", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "BaseTensor" ], "chunk_id": "class_Masks_40d9497b" }, { "content": "class Keypoints(BaseTensor):\n \"\"\"\n A class for storing and manipulating detection keypoints.\n\n This class encapsulates functionality for handling keypoint data, including coordinate manipulation,\n normalization, and confidence values. 
It supports keypoint detection results with optional visibility\n information.\n\n Attributes:\n data (torch.Tensor): The raw tensor containing keypoint data.\n orig_shape (Tuple[int, int]): The original image dimensions (height, width).\n has_visible (bool): Indicates whether visibility information is available for keypoints.\n xy (torch.Tensor): Keypoint coordinates in [x, y] format.\n xyn (torch.Tensor): Normalized keypoint coordinates in [x, y] format, relative to orig_shape.\n conf (torch.Tensor): Confidence values for each keypoint, if available.\n\n Methods:\n cpu: Return a copy of the keypoints tensor on CPU memory.\n numpy: Return a copy of the keypoints tensor as a numpy array.\n cuda: Return a copy of the keypoints tensor on GPU memory.\n to: Return a copy of the keypoints tensor with specified device and dtype.\n\n Examples:\n >>> import torch\n >>> from ultralytics.engine.results import Keypoints\n >>> keypoints_data = torch.rand(1, 17, 3) # 1 detection, 17 keypoints, (x, y, conf)\n >>> orig_shape = (480, 640) # Original image shape (height, width)\n >>> keypoints = Keypoints(keypoints_data, orig_shape)\n >>> print(keypoints.xy.shape) # Access xy coordinates\n >>> print(keypoints.conf) # Access confidence values\n >>> keypoints_cpu = keypoints.cpu() # Move keypoints to CPU\n \"\"\"\n\n def __init__(self, keypoints: Union[torch.Tensor, np.ndarray], orig_shape: Tuple[int, int]) -> None:\n \"\"\"\n Initialize the Keypoints object with detection keypoints and original image dimensions.\n\n This method processes the input keypoints tensor, handling both 2D and 3D formats. For 3D tensors\n (x, y, confidence), it masks out low-confidence keypoints by setting their coordinates to zero.\n\n Args:\n keypoints (torch.Tensor): A tensor containing keypoint data. 
Shape can be either:\n - (num_objects, num_keypoints, 2) for x, y coordinates only\n - (num_objects, num_keypoints, 3) for x, y coordinates and confidence scores\n orig_shape (Tuple[int, int]): The original image dimensions (height, width).\n\n Examples:\n >>> kpts = torch.rand(1, 17, 3) # 1 object, 17 keypoints (COCO format), x,y,conf\n >>> orig_shape = (720, 1280) # Original image height, width\n >>> keypoints = Keypoints(kpts, orig_shape)\n \"\"\"\n if keypoints.ndim == 2:\n keypoints = keypoints[None, :]\n super().__init__(keypoints, orig_shape)\n self.has_visible = self.data.shape[-1] == 3\n\n @property\n @lru_cache(maxsize=1)\n def xy(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return x, y coordinates of keypoints.\n\n Returns:\n (torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is\n the number of detections and K is the number of keypoints per detection.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> keypoints = results[0].keypoints\n >>> xy = keypoints.xy\n >>> print(xy.shape) # (N, K, 2)\n >>> print(xy[0]) # x, y coordinates of keypoints for first detection\n\n Notes:\n - The returned coordinates are in pixel units relative to the original image dimensions.\n - If keypoints were initialized with confidence values, only keypoints with confidence >= 0.5 are returned.\n - This property uses LRU caching to improve performance on repeated access.\n \"\"\"\n return self.data[..., :2]\n\n @property\n @lru_cache(maxsize=1)\n def xyn(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return normalized coordinates (x, y) of keypoints relative to the original image size.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or array of shape (N, K, 2) containing normalized keypoint\n coordinates, where N is the number of instances, K is the number of keypoints, and the last\n dimension contains [x, y] values in the range [0, 1].\n\n Examples:\n >>> keypoints = Keypoints(torch.rand(1, 17, 2), orig_shape=(480, 640))\n >>> normalized_kpts = keypoints.xyn\n >>> print(normalized_kpts.shape)\n torch.Size([1, 17, 2])\n \"\"\"\n xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)\n xy[..., 0] /= self.orig_shape[1]\n xy[..., 1] /= self.orig_shape[0]\n return xy\n\n @property\n @lru_cache(maxsize=1)\n def conf(self) -> Optional[Union[torch.Tensor, np.ndarray]]:\n \"\"\"\n Return confidence values for each keypoint.\n\n Returns:\n (torch.Tensor | None): A tensor containing confidence scores for each keypoint if available,\n otherwise None. Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,)\n for single detection.\n\n Examples:\n >>> keypoints = Keypoints(torch.rand(1, 17, 3), orig_shape=(640, 640)) # 1 detection, 17 keypoints\n >>> conf = keypoints.conf\n >>> print(conf.shape) # torch.Size([1, 17])\n \"\"\"\n return self.data[..., 2] if self.has_visible else None", "chunk_type": "class", "name": "Keypoints", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 1173, "end_line": 1291, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "A class for storing and manipulating detection keypoints.\n\nThis class encapsulates functionality for handling keypoint data, including coordinate manipulation,\nnormalization, and confidence values. 
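A sketch of the Keypoints accessors with one fabricated detection of 17 COCO-style keypoints; the image size and random values are illustrative only:

```python
import torch

from ultralytics.engine.results import Keypoints

# One detection with 17 keypoints as (x, y, confidence) triplets.
kpts = torch.rand(1, 17, 3)
kpts[..., 0] *= 640  # x in pixels
kpts[..., 1] *= 480  # y in pixels
keypoints = Keypoints(kpts, orig_shape=(480, 640))

print(keypoints.has_visible)  # True, because the last dimension has 3 values
print(keypoints.xy.shape)     # torch.Size([1, 17, 2]) pixel coordinates
print(keypoints.xyn[0, :3])   # first three keypoints, normalized to [0, 1]
print(keypoints.conf.shape)   # torch.Size([1, 17])
```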
It supports keypoint detection results with optional visibility\ninformation.\n\nAttributes:\n data (torch.Tensor): The raw tensor containing keypoint data.\n orig_shape (Tuple[int, int]): The original image dimensions (height, width).\n has_visible (bool): Indicates whether visibility information is available for keypoints.\n xy (torch.Tensor): Keypoint coordinates in [x, y] format.\n xyn (torch.Tensor): Normalized keypoint coordinates in [x, y] format, relative to orig_shape.\n conf (torch.Tensor): Confidence values for each keypoint, if available.\n\nMethods:\n cpu: Return a copy of the keypoints tensor on CPU memory.\n numpy: Return a copy of the keypoints tensor as a numpy array.\n cuda: Return a copy of the keypoints tensor on GPU memory.\n to: Return a copy of the keypoints tensor with specified device and dtype.\n\nExamples:\n >>> import torch\n >>> from ultralytics.engine.results import Keypoints\n >>> keypoints_data = torch.rand(1, 17, 3) # 1 detection, 17 keypoints, (x, y, conf)\n >>> orig_shape = (480, 640) # Original image shape (height, width)\n >>> keypoints = Keypoints(keypoints_data, orig_shape)\n >>> print(keypoints.xy.shape) # Access xy coordinates\n >>> print(keypoints.conf) # Access confidence values\n >>> keypoints_cpu = keypoints.cpu() # Move keypoints to CPU", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "BaseTensor" ], "chunk_id": "class_Keypoints_18f65f76" }, { "content": "class Probs(BaseTensor):\n \"\"\"\n A class for storing and manipulating classification probabilities.\n\n This class extends BaseTensor and provides methods for accessing and manipulating\n classification probabilities, including top-1 and top-5 predictions.\n\n Attributes:\n data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.\n orig_shape (tuple | None): The original image shape as (height, width). 
Not used in this class.\n top1 (int): Index of the class with the highest probability.\n top5 (List[int]): Indices of the top 5 classes by probability.\n top1conf (torch.Tensor | np.ndarray): Confidence score of the top 1 class.\n top5conf (torch.Tensor | np.ndarray): Confidence scores of the top 5 classes.\n\n Methods:\n cpu: Return a copy of the probabilities tensor on CPU memory.\n numpy: Return a copy of the probabilities tensor as a numpy array.\n cuda: Return a copy of the probabilities tensor on GPU memory.\n to: Return a copy of the probabilities tensor with specified device and dtype.\n\n Examples:\n >>> probs = torch.tensor([0.1, 0.3, 0.6])\n >>> p = Probs(probs)\n >>> print(p.top1)\n 2\n >>> print(p.top5)\n [2, 1, 0]\n >>> print(p.top1conf)\n tensor(0.6000)\n >>> print(p.top5conf)\n tensor([0.6000, 0.3000, 0.1000])\n \"\"\"\n\n def __init__(self, probs: Union[torch.Tensor, np.ndarray], orig_shape: Optional[Tuple[int, int]] = None) -> None:\n \"\"\"\n Initialize the Probs class with classification probabilities.\n\n This class stores and manages classification probabilities, providing easy access to top predictions and their\n confidences.\n\n Args:\n probs (torch.Tensor | np.ndarray): A 1D tensor or array of classification probabilities.\n orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept\n for consistency with other result classes.\n\n Attributes:\n data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.\n top1 (int): Index of the top 1 class.\n top5 (List[int]): Indices of the top 5 classes.\n top1conf (torch.Tensor | np.ndarray): Confidence of the top 1 class.\n top5conf (torch.Tensor | np.ndarray): Confidences of the top 5 classes.\n\n Examples:\n >>> import torch\n >>> probs = torch.tensor([0.1, 0.3, 0.2, 0.4])\n >>> p = Probs(probs)\n >>> print(p.top1)\n 3\n >>> print(p.top1conf)\n tensor(0.4000)\n >>> print(p.top5)\n [3, 1, 2, 0]\n \"\"\"\n super().__init__(probs, orig_shape)\n\n @property\n @lru_cache(maxsize=1)\n def top1(self) -> int:\n \"\"\"\n Return the index of the class with the highest probability.\n\n Returns:\n (int): Index of the class with the highest probability.\n\n Examples:\n >>> probs = Probs(torch.tensor([0.1, 0.3, 0.6]))\n >>> probs.top1\n 2\n \"\"\"\n return int(self.data.argmax())\n\n @property\n @lru_cache(maxsize=1)\n def top5(self) -> List[int]:\n \"\"\"\n Return the indices of the top 5 class probabilities.\n\n Returns:\n (List[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order.\n\n Examples:\n >>> probs = Probs(torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]))\n >>> print(probs.top5)\n [4, 3, 2, 1, 0]\n \"\"\"\n return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy.\n\n @property\n @lru_cache(maxsize=1)\n def top1conf(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return the confidence score of the highest probability class.\n\n This property retrieves the confidence score (probability) of the class with the highest predicted probability\n from the classification results.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor containing the confidence score of the top 1 class.\n\n Examples:\n >>> results = model(\"image.jpg\") # classify an image\n >>> probs = results[0].probs # get classification probabilities\n >>> top1_confidence = probs.top1conf # get confidence of top 1 class\n >>> print(f\"Top 1 class confidence: {top1_confidence.item():.4f}\")\n \"\"\"\n return 
self.data[self.top1]\n\n @property\n @lru_cache(maxsize=1)\n def top5conf(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return confidence scores for the top 5 classification predictions.\n\n This property retrieves the confidence scores corresponding to the top 5 class probabilities\n predicted by the model. It provides a quick way to access the most likely class predictions\n along with their associated confidence levels.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or array containing the confidence scores for the\n top 5 predicted classes, sorted in descending order of probability.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> probs = results[0].probs\n >>> top5_conf = probs.top5conf\n >>> print(top5_conf) # Prints confidence scores for top 5 classes\n \"\"\"\n return self.data[self.top5]", "chunk_type": "class", "name": "Probs", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 1294, "end_line": 1432, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "A class for storing and manipulating classification probabilities.\n\nThis class extends BaseTensor and provides methods for accessing and manipulating\nclassification probabilities, including top-1 and top-5 predictions.\n\nAttributes:\n data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.\n orig_shape (tuple | None): The original image shape as (height, width). Not used in this class.\n top1 (int): Index of the class with the highest probability.\n top5 (List[int]): Indices of the top 5 classes by probability.\n top1conf (torch.Tensor | np.ndarray): Confidence score of the top 1 class.\n top5conf (torch.Tensor | np.ndarray): Confidence scores of the top 5 classes.\n\nMethods:\n cpu: Return a copy of the probabilities tensor on CPU memory.\n numpy: Return a copy of the probabilities tensor as a numpy array.\n cuda: Return a copy of the probabilities tensor on GPU memory.\n to: Return a copy of the probabilities tensor with specified device and dtype.\n\nExamples:\n >>> probs = torch.tensor([0.1, 0.3, 0.6])\n >>> p = Probs(probs)\n >>> print(p.top1)\n 2\n >>> print(p.top5)\n [2, 1, 0]\n >>> print(p.top1conf)\n tensor(0.6000)\n >>> print(p.top5conf)\n tensor([0.6000, 0.3000, 0.1000])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "BaseTensor" ], "chunk_id": "class_Probs_df508b0d" }, { "content": "class OBB(BaseTensor):\n \"\"\"\n A class for storing and manipulating Oriented Bounding Boxes (OBB).\n\n This class provides functionality to handle oriented bounding boxes, including conversion between\n different formats, normalization, and access to various properties of the boxes. 
It supports\n both tracking and non-tracking scenarios.\n\n Attributes:\n data (torch.Tensor): The raw OBB tensor containing box coordinates and associated data.\n orig_shape (tuple): Original image size as (height, width).\n is_track (bool): Indicates whether tracking IDs are included in the box data.\n xywhr (torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height, rotation] format.\n conf (torch.Tensor | np.ndarray): Confidence scores for each box.\n cls (torch.Tensor | np.ndarray): Class labels for each box.\n id (torch.Tensor | np.ndarray): Tracking IDs for each box, if available.\n xyxyxyxy (torch.Tensor | np.ndarray): Boxes in 8-point [x1, y1, x2, y2, x3, y3, x4, y4] format.\n xyxyxyxyn (torch.Tensor | np.ndarray): Normalized 8-point coordinates relative to orig_shape.\n xyxy (torch.Tensor | np.ndarray): Axis-aligned bounding boxes in [x1, y1, x2, y2] format.\n\n Methods:\n cpu: Return a copy of the OBB object with all tensors on CPU memory.\n numpy: Return a copy of the OBB object with all tensors as numpy arrays.\n cuda: Return a copy of the OBB object with all tensors on GPU memory.\n to: Return a copy of the OBB object with tensors on specified device and dtype.\n\n Examples:\n >>> boxes = torch.tensor([[100, 50, 150, 100, 30, 0.9, 0]]) # xywhr, conf, cls\n >>> obb = OBB(boxes, orig_shape=(480, 640))\n >>> print(obb.xyxyxyxy)\n >>> print(obb.conf)\n >>> print(obb.cls)\n \"\"\"\n\n def __init__(self, boxes: Union[torch.Tensor, np.ndarray], orig_shape: Tuple[int, int]) -> None:\n \"\"\"\n Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape.\n\n This class stores and manipulates Oriented Bounding Boxes (OBB) for object detection tasks. It provides\n various properties and methods to access and transform the OBB data.\n\n Args:\n boxes (torch.Tensor | np.ndarray): A tensor or numpy array containing the detection boxes,\n with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.\n If present, the third last column contains track IDs, and the fifth column contains rotation.\n orig_shape (Tuple[int, int]): Original image size, in the format (height, width).\n\n Attributes:\n data (torch.Tensor | np.ndarray): The raw OBB tensor.\n orig_shape (Tuple[int, int]): The original image shape.\n is_track (bool): Whether the boxes include tracking IDs.\n\n Raises:\n AssertionError: If the number of values per box is not 7 or 8.\n\n Examples:\n >>> import torch\n >>> boxes = torch.rand(3, 7) # 3 boxes with 7 values each\n >>> orig_shape = (640, 480)\n >>> obb = OBB(boxes, orig_shape)\n >>> print(obb.xywhr) # Access the boxes in xywhr format\n \"\"\"\n if boxes.ndim == 1:\n boxes = boxes[None, :]\n n = boxes.shape[-1]\n assert n in {7, 8}, f\"expected 7 or 8 values but got {n}\" # xywh, rotation, track_id, conf, cls\n super().__init__(boxes, orig_shape)\n self.is_track = n == 8\n self.orig_shape = orig_shape\n\n @property\n def xywhr(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return boxes in [x_center, y_center, width, height, rotation] format.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or numpy array containing the oriented bounding boxes with format\n [x_center, y_center, width, height, rotation]. 
The shape is (N, 5) where N is the number of boxes.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> obb = results[0].obb\n >>> xywhr = obb.xywhr\n >>> print(xywhr.shape)\n torch.Size([3, 5])\n \"\"\"\n return self.data[:, :5]\n\n @property\n def conf(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return the confidence scores for Oriented Bounding Boxes (OBBs).\n\n This property retrieves the confidence values associated with each OBB detection. The confidence score\n represents the model's certainty in the detection.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or numpy array of shape (N,) containing confidence scores\n for N detections, where each score is in the range [0, 1].\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> obb_result = results[0].obb\n >>> confidence_scores = obb_result.conf\n >>> print(confidence_scores)\n \"\"\"\n return self.data[:, -2]\n\n @property\n def cls(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Return the class values of the oriented bounding boxes.\n\n Returns:\n (torch.Tensor | np.ndarray): A tensor or numpy array containing the class values for each oriented\n bounding box. The shape is (N,), where N is the number of boxes.\n\n Examples:\n >>> results = model(\"image.jpg\")\n >>> result = results[0]\n >>> obb = result.obb\n >>> class_values = obb.cls\n >>> print(class_values)\n \"\"\"\n return self.data[:, -1]\n\n @property\n def id(self) -> Optional[Union[torch.Tensor, np.ndarray]]:\n \"\"\"\n Return the tracking IDs of the oriented bounding boxes (if available).\n\n Returns:\n (torch.Tensor | np.ndarray | None): A tensor or numpy array containing the tracking IDs for each\n oriented bounding box. Returns None if tracking IDs are not available.\n\n Examples:\n >>> results = model(\"image.jpg\", tracker=True) # Run inference with tracking\n >>> for result in results:\n ... if result.obb is not None:\n ... track_ids = result.obb.id\n ... if track_ids is not None:\n ... print(f\"Tracking IDs: {track_ids}\")\n \"\"\"\n return self.data[:, -3] if self.is_track else None\n\n @property\n @lru_cache(maxsize=2)\n def xyxyxyxy(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Convert OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes.\n\n Returns:\n (torch.Tensor | np.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is\n the number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and\n moving clockwise.\n\n Examples:\n >>> obb = OBB(torch.tensor([[100, 100, 50, 30, 0.5, 0.9, 0]]), orig_shape=(640, 640))\n >>> xyxyxyxy = obb.xyxyxyxy\n >>> print(xyxyxyxy.shape)\n torch.Size([1, 4, 2])\n \"\"\"\n return ops.xywhr2xyxyxyxy(self.xywhr)\n\n @property\n @lru_cache(maxsize=2)\n def xyxyxyxyn(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Convert rotated bounding boxes to normalized xyxyxyxy format.\n\n Returns:\n (torch.Tensor | np.ndarray): Normalized rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2),\n where N is the number of boxes. 
Each box is represented by 4 points (x, y), normalized relative to\n the original image dimensions.\n\n Examples:\n >>> obb = OBB(torch.rand(10, 7), orig_shape=(640, 480)) # 10 random OBBs\n >>> normalized_boxes = obb.xyxyxyxyn\n >>> print(normalized_boxes.shape)\n torch.Size([10, 4, 2])\n \"\"\"\n xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy)\n xyxyxyxyn[..., 0] /= self.orig_shape[1]\n xyxyxyxyn[..., 1] /= self.orig_shape[0]\n return xyxyxyxyn\n\n @property\n @lru_cache(maxsize=2)\n def xyxy(self) -> Union[torch.Tensor, np.ndarray]:\n \"\"\"\n Convert oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format.\n\n This property calculates the minimal enclosing rectangle for each oriented bounding box and returns it in\n xyxy format (x1, y1, x2, y2). This is useful for operations that require axis-aligned bounding boxes, such\n as IoU calculation with non-rotated boxes.\n\n Returns:\n (torch.Tensor | np.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N\n is the number of boxes. Each row contains [x1, y1, x2, y2] coordinates.\n\n Examples:\n >>> import torch\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n-obb.pt\")\n >>> results = model(\"path/to/image.jpg\")\n >>> for result in results:\n ... obb = result.obb\n ... if obb is not None:\n ... xyxy_boxes = obb.xyxy\n ... print(xyxy_boxes.shape) # (N, 4)\n\n Notes:\n - This method approximates the OBB by its minimal enclosing rectangle.\n - The returned format is compatible with standard object detection metrics and visualization tools.\n - The property uses caching to improve performance for repeated access.\n \"\"\"\n x = self.xyxyxyxy[..., 0]\n y = self.xyxyxyxy[..., 1]\n return (\n torch.stack([x.amin(1), y.amin(1), x.amax(1), y.amax(1)], -1)\n if isinstance(x, torch.Tensor)\n else np.stack([x.min(1), y.min(1), x.max(1), y.max(1)], -1)\n )", "chunk_type": "class", "name": "OBB", "file_path": "ultralytics\\ultralytics\\engine\\results.py", "start_line": 1435, "end_line": 1657, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class for storing and manipulating Oriented Bounding Boxes (OBB).\n\nThis class provides functionality to handle oriented bounding boxes, including conversion between\ndifferent formats, normalization, and access to various properties of the boxes. 
It supports\nboth tracking and non-tracking scenarios.\n\nAttributes:\n data (torch.Tensor): The raw OBB tensor containing box coordinates and associated data.\n orig_shape (tuple): Original image size as (height, width).\n is_track (bool): Indicates whether tracking IDs are included in the box data.\n xywhr (torch.Tensor | np.ndarray): Boxes in [x_center, y_center, width, height, rotation] format.\n conf (torch.Tensor | np.ndarray): Confidence scores for each box.\n cls (torch.Tensor | np.ndarray): Class labels for each box.\n id (torch.Tensor | np.ndarray): Tracking IDs for each box, if available.\n xyxyxyxy (torch.Tensor | np.ndarray): Boxes in 8-point [x1, y1, x2, y2, x3, y3, x4, y4] format.\n xyxyxyxyn (torch.Tensor | np.ndarray): Normalized 8-point coordinates relative to orig_shape.\n xyxy (torch.Tensor | np.ndarray): Axis-aligned bounding boxes in [x1, y1, x2, y2] format.\n\nMethods:\n cpu: Return a copy of the OBB object with all tensors on CPU memory.\n numpy: Return a copy of the OBB object with all tensors as numpy arrays.\n cuda: Return a copy of the OBB object with all tensors on GPU memory.\n to: Return a copy of the OBB object with tensors on specified device and dtype.\n\nExamples:\n >>> boxes = torch.tensor([[100, 50, 150, 100, 30, 0.9, 0]]) # xywhr, conf, cls\n >>> obb = OBB(boxes, orig_shape=(480, 640))\n >>> print(obb.xyxyxyxy)\n >>> print(obb.conf)\n >>> print(obb.cls)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "functools.lru_cache", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.augment.LetterBox", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.ops", "ultralytics.utils.plotting.Annotator", "ultralytics.utils.plotting.colors", "ultralytics.utils.plotting.save_one_box", "BaseTensor" ], "chunk_id": "class_OBB_89505182" }, { "content": "import gc", "chunk_type": "import", "name": "gc", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_gc_a22d7e33" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_0080b56b" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_3998dd71" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_af901d35" }, { 
"content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_662896b3" }, { "content": "import warnings", "chunk_type": "import", "name": "warnings", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_warnings_76e0a7cf" }, { "content": "from copy import copy, deepcopy", "chunk_type": "import", "name": "copy, deepcopy", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy, deepcopy_8f8e6ad4" }, { "content": "from datetime import datetime, timedelta", "chunk_type": "import", "name": "datetime, timedelta", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_datetime, timedelta_10cfdbd9" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_2c5819ac" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_4397cc17" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_3cacbc99" }, { "content": "from torch import distributed as dist", "chunk_type": "import", "name": "distributed", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_distributed_4cab7084" }, { "content": "from torch import nn, optim", "chunk_type": "import", "name": "nn, optim", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 22, "end_line": 22, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
"complexity_score": null, "dependencies": null, "chunk_id": "import_nn, optim_d6b988f2" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___c771f5fb" }, { "content": "from ultralytics.cfg import get_cfg, get_save_dir", "chunk_type": "import", "name": "get_cfg, get_save_dir", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cfg, get_save_dir_d5048250" }, { "content": "from ultralytics.data.utils import check_cls_dataset, check_det_dataset", "chunk_type": "import", "name": "check_cls_dataset, check_det_dataset", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 26, "end_line": 26, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_cls_dataset, check_det_dataset_8f2647f9" }, { "content": "from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights", "chunk_type": "import", "name": "attempt_load_one_weight, attempt_load_weights", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_attempt_load_one_weight, attempt_load_weights_f56db6b6" }, { "content": "from ultralytics.utils import (\n DEFAULT_CFG,\n LOCAL_RANK,\n LOGGER,\n RANK,\n TQDM,\n YAML,\n callbacks,\n clean_url,\n colorstr,\n emojis,\n)", "chunk_type": "import", "name": "DEFAULT_CFG, LOCAL_RANK, LOGGER, RANK, TQDM, YAML, callbacks, clean_url, colorstr, emojis", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 28, "end_line": 39, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOCAL_RANK, LOGGER, RANK, TQDM, YAML, callbacks, clean_url, colorstr, emojis_27c9a95c" }, { "content": "from ultralytics.utils.autobatch import check_train_batch_size", "chunk_type": "import", "name": "check_train_batch_size", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 40, "end_line": 40, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_train_batch_size_c2a411e9" }, { "content": "from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args", "chunk_type": "import", "name": "check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 41, "end_line": 41, "start_col": 0, "end_col": 111, "parent_name": null, "docstring": null, 
"parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args_66b502a9" }, { "content": "from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command", "chunk_type": "import", "name": "ddp_cleanup, generate_ddp_command", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 42, "end_line": 42, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ddp_cleanup, generate_ddp_command_186139cd" }, { "content": "from ultralytics.utils.files import get_latest_run", "chunk_type": "import", "name": "get_latest_run", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 43, "end_line": 43, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_latest_run_a00d82a8" }, { "content": "from ultralytics.utils.torch_utils import (\n TORCH_2_4,\n EarlyStopping,\n ModelEMA,\n autocast,\n convert_optimizer_state_dict_to_fp16,\n init_seeds,\n one_cycle,\n select_device,\n strip_optimizer,\n torch_distributed_zero_first,\n unset_deterministic,\n)", "chunk_type": "import", "name": "TORCH_2_4, EarlyStopping, ModelEMA, autocast, convert_optimizer_state_dict_to_fp16, init_seeds, one_cycle, select_device, strip_optimizer, torch_distributed_zero_first, unset_deterministic", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 44, "end_line": 56, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_2_4, EarlyStopping, ModelEMA, autocast, convert_optimizer_state_dict_to_fp16, init_seeds, one_cycle, select_device, strip_optimizer, torch_distributed_zero_first, unset_deterministic_d00ca6c2" }, { "content": "class BaseTrainer:\n \"\"\"\n A base class for creating trainers.\n\n This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,\n and various training utilities. 
It supports both single-GPU and multi-GPU distributed training.\n\n Attributes:\n args (SimpleNamespace): Configuration for the trainer.\n validator (BaseValidator): Validator instance.\n model (nn.Module): Model instance.\n callbacks (defaultdict): Dictionary of callbacks.\n save_dir (Path): Directory to save results.\n wdir (Path): Directory to save weights.\n last (Path): Path to the last checkpoint.\n best (Path): Path to the best checkpoint.\n save_period (int): Save checkpoint every x epochs (disabled if < 1).\n batch_size (int): Batch size for training.\n epochs (int): Number of epochs to train for.\n start_epoch (int): Starting epoch for training.\n device (torch.device): Device to use for training.\n amp (bool): Flag to enable AMP (Automatic Mixed Precision).\n scaler (amp.GradScaler): Gradient scaler for AMP.\n data (str): Path to data.\n ema (nn.Module): EMA (Exponential Moving Average) of the model.\n resume (bool): Resume training from a checkpoint.\n lf (nn.Module): Loss function.\n scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.\n best_fitness (float): The best fitness value achieved.\n fitness (float): Current fitness value.\n loss (float): Current loss value.\n tloss (float): Total loss value.\n loss_names (list): List of loss names.\n csv (Path): Path to results CSV file.\n metrics (dict): Dictionary of metrics.\n plots (dict): Dictionary of plots.\n\n Methods:\n train: Execute the training process.\n validate: Run validation on the test set.\n save_model: Save model training checkpoints.\n get_dataset: Get train and validation datasets.\n setup_model: Load, create, or download model.\n build_optimizer: Construct an optimizer for the model.\n\n Examples:\n Initialize a trainer and start training\n >>> trainer = BaseTrainer(cfg=\"config.yaml\")\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the BaseTrainer class.\n\n Args:\n cfg (str, optional): Path to a configuration file.\n overrides (dict, optional): Configuration overrides.\n _callbacks (list, optional): List of callback functions.\n \"\"\"\n self.args = get_cfg(cfg, overrides)\n self.check_resume(overrides)\n self.device = select_device(self.args.device, self.args.batch)\n # Update \"-1\" devices so post-training val does not repeat search\n self.args.device = os.getenv(\"CUDA_VISIBLE_DEVICES\") if \"cuda\" in str(self.device) else str(self.device)\n self.validator = None\n self.metrics = None\n self.plots = {}\n init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)\n\n # Dirs\n self.save_dir = get_save_dir(self.args)\n self.args.name = self.save_dir.name # update name for loggers\n self.wdir = self.save_dir / \"weights\" # weights dir\n if RANK in {-1, 0}:\n self.wdir.mkdir(parents=True, exist_ok=True) # make dir\n self.args.save_dir = str(self.save_dir)\n YAML.save(self.save_dir / \"args.yaml\", vars(self.args)) # save run args\n self.last, self.best = self.wdir / \"last.pt\", self.wdir / \"best.pt\" # checkpoint paths\n self.save_period = self.args.save_period\n\n self.batch_size = self.args.batch\n self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training\n self.start_epoch = 0\n if RANK == -1:\n print_args(vars(self.args))\n\n # Device\n if self.device.type in {\"cpu\", \"mps\"}:\n self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading\n\n # Model and Dataset\n self.model = 
check_model_file_from_stem(self.args.model) # add suffix, i.e. yolo11n -> yolo11n.pt\n with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times\n self.data = self.get_dataset()\n\n self.ema = None\n\n # Optimization utils init\n self.lf = None\n self.scheduler = None\n\n # Epoch level metrics\n self.best_fitness = None\n self.fitness = None\n self.loss = None\n self.tloss = None\n self.loss_names = [\"Loss\"]\n self.csv = self.save_dir / \"results.csv\"\n self.plot_idx = [0, 1, 2]\n\n # HUB\n self.hub_session = None\n\n # Callbacks\n self.callbacks = _callbacks or callbacks.get_default_callbacks()\n if RANK in {-1, 0}:\n callbacks.add_integration_callbacks(self)\n\n def add_callback(self, event: str, callback):\n \"\"\"Append the given callback to the event's callback list.\"\"\"\n self.callbacks[event].append(callback)\n\n def set_callback(self, event: str, callback):\n \"\"\"Override the existing callbacks with the given callback for the specified event.\"\"\"\n self.callbacks[event] = [callback]\n\n def run_callbacks(self, event: str):\n \"\"\"Run all existing callbacks associated with a particular event.\"\"\"\n for callback in self.callbacks.get(event, []):\n callback(self)\n\n def train(self):\n \"\"\"Allow device='', device=None on Multi-GPU systems to default to device=0.\"\"\"\n if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'\n world_size = len(self.args.device.split(\",\"))\n elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)\n world_size = len(self.args.device)\n elif self.args.device in {\"cpu\", \"mps\"}: # i.e. device='cpu' or 'mps'\n world_size = 0\n elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number\n world_size = 1 # default to device 0\n else: # i.e. 
device=None or device=''\n world_size = 0\n\n # Run subprocess if DDP training, else train normally\n if world_size > 1 and \"LOCAL_RANK\" not in os.environ:\n # Argument checks\n if self.args.rect:\n LOGGER.warning(\"'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'\")\n self.args.rect = False\n if self.args.batch < 1.0:\n LOGGER.warning(\n \"'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'\"\n )\n self.args.batch = 16\n\n # Command\n cmd, file = generate_ddp_command(world_size, self)\n try:\n LOGGER.info(f\"{colorstr('DDP:')} debug command {' '.join(cmd)}\")\n subprocess.run(cmd, check=True)\n except Exception as e:\n raise e\n finally:\n ddp_cleanup(self, str(file))\n\n else:\n self._do_train(world_size)\n\n def _setup_scheduler(self):\n \"\"\"Initialize training learning rate scheduler.\"\"\"\n if self.args.cos_lr:\n self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']\n else:\n self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear\n self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)\n\n def _setup_ddp(self, world_size):\n \"\"\"Initialize and set the DistributedDataParallel parameters for training.\"\"\"\n torch.cuda.set_device(RANK)\n self.device = torch.device(\"cuda\", RANK)\n # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')\n os.environ[\"TORCH_NCCL_BLOCKING_WAIT\"] = \"1\" # set to enforce timeout\n dist.init_process_group(\n backend=\"nccl\" if dist.is_nccl_available() else \"gloo\",\n timeout=timedelta(seconds=10800), # 3 hours\n rank=RANK,\n world_size=world_size,\n )\n\n def _setup_train(self, world_size):\n \"\"\"Build dataloaders and optimizer on correct rank process.\"\"\"\n # Model\n self.run_callbacks(\"on_pretrain_routine_start\")\n ckpt = self.setup_model()\n self.model = self.model.to(self.device)\n self.set_model_attributes()\n\n # Freeze layers\n freeze_list = (\n self.args.freeze\n if isinstance(self.args.freeze, list)\n else range(self.args.freeze)\n if isinstance(self.args.freeze, int)\n else []\n )\n always_freeze_names = [\".dfl\"] # always freeze these layers\n freeze_layer_names = [f\"model.{x}.\" for x in freeze_list] + always_freeze_names\n self.freeze_layer_names = freeze_layer_names\n for k, v in self.model.named_parameters():\n # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)\n if any(x in k for x in freeze_layer_names):\n LOGGER.info(f\"Freezing layer '{k}'\")\n v.requires_grad = False\n elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients\n LOGGER.warning(\n f\"setting 'requires_grad=True' for frozen layer '{k}'. 
\"\n \"See ultralytics.engine.trainer for customization of frozen layers.\"\n )\n v.requires_grad = True\n\n # Check AMP\n self.amp = torch.tensor(self.args.amp).to(self.device) # True or False\n if self.amp and RANK in {-1, 0}: # Single-GPU and DDP\n callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them\n self.amp = torch.tensor(check_amp(self.model), device=self.device)\n callbacks.default_callbacks = callbacks_backup # restore callbacks\n if RANK > -1 and world_size > 1: # DDP\n dist.broadcast(self.amp.int(), src=0) # broadcast from rank 0 to all other ranks; gloo errors with boolean\n self.amp = bool(self.amp) # as boolean\n self.scaler = (\n torch.amp.GradScaler(\"cuda\", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)\n )\n if world_size > 1:\n self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)\n\n # Check imgsz\n gs = max(int(self.model.stride.max() if hasattr(self.model, \"stride\") else 32), 32) # grid size (max stride)\n self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)\n self.stride = gs # for multiscale training\n\n # Batch size\n if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size\n self.args.batch = self.batch_size = self.auto_batch()\n\n # Dataloaders\n batch_size = self.batch_size // max(world_size, 1)\n self.train_loader = self.get_dataloader(\n self.data[\"train\"], batch_size=batch_size, rank=LOCAL_RANK, mode=\"train\"\n )\n if RANK in {-1, 0}:\n # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.\n self.test_loader = self.get_dataloader(\n self.data.get(\"val\") or self.data.get(\"test\"),\n batch_size=batch_size if self.args.task == \"obb\" else batch_size * 2,\n rank=-1,\n mode=\"val\",\n )\n self.validator = self.get_validator()\n metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix=\"val\")\n self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))\n self.ema = ModelEMA(self.model)\n if self.args.plots:\n self.plot_training_labels()\n\n # Optimizer\n self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing\n weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay\n iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs\n self.optimizer = self.build_optimizer(\n model=self.model,\n name=self.args.optimizer,\n lr=self.args.lr0,\n momentum=self.args.momentum,\n decay=weight_decay,\n iterations=iterations,\n )\n # Scheduler\n self._setup_scheduler()\n self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False\n self.resume_training(ckpt)\n self.scheduler.last_epoch = self.start_epoch - 1 # do not move\n self.run_callbacks(\"on_pretrain_routine_end\")\n\n def _do_train(self, world_size=1):\n \"\"\"Train the model with the specified world size.\"\"\"\n if world_size > 1:\n self._setup_ddp(world_size)\n self._setup_train(world_size)\n\n nb = len(self.train_loader) # number of batches\n nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations\n last_opt_step = -1\n self.epoch_time = None\n self.epoch_time_start = time.time()\n self.train_time_start = time.time()\n self.run_callbacks(\"on_train_start\")\n LOGGER.info(\n f\"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\\n\"\n f\"Using 
{self.train_loader.num_workers * (world_size or 1)} dataloader workers\\n\"\n f\"Logging results to {colorstr('bold', self.save_dir)}\\n\"\n f\"Starting training for \" + (f\"{self.args.time} hours...\" if self.args.time else f\"{self.epochs} epochs...\")\n )\n if self.args.close_mosaic:\n base_idx = (self.epochs - self.args.close_mosaic) * nb\n self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])\n epoch = self.start_epoch\n self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start\n while True:\n self.epoch = epoch\n self.run_callbacks(\"on_train_epoch_start\")\n with warnings.catch_warnings():\n warnings.simplefilter(\"ignore\") # suppress 'Detected lr_scheduler.step() before optimizer.step()'\n self.scheduler.step()\n\n self._model_train()\n if RANK != -1:\n self.train_loader.sampler.set_epoch(epoch)\n pbar = enumerate(self.train_loader)\n # Update dataloader attributes (optional)\n if epoch == (self.epochs - self.args.close_mosaic):\n self._close_dataloader_mosaic()\n self.train_loader.reset()\n\n if RANK in {-1, 0}:\n LOGGER.info(self.progress_string())\n pbar = TQDM(enumerate(self.train_loader), total=nb)\n self.tloss = None\n for i, batch in pbar:\n self.run_callbacks(\"on_train_batch_start\")\n # Warmup\n ni = i + nb * epoch\n if ni <= nw:\n xi = [0, nw] # x interp\n self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))\n for j, x in enumerate(self.optimizer.param_groups):\n # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0\n x[\"lr\"] = np.interp(\n ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x[\"initial_lr\"] * self.lf(epoch)]\n )\n if \"momentum\" in x:\n x[\"momentum\"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])\n\n # Forward\n with autocast(self.amp):\n batch = self.preprocess_batch(batch)\n loss, self.loss_items = self.model(batch)\n self.loss = loss.sum()\n if RANK != -1:\n self.loss *= world_size\n self.tloss = (\n (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items\n )\n\n # Backward\n self.scaler.scale(self.loss).backward()\n\n # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html\n if ni - last_opt_step >= self.accumulate:\n self.optimizer_step()\n last_opt_step = ni\n\n # Timed stopping\n if self.args.time:\n self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)\n if RANK != -1: # if DDP training\n broadcast_list = [self.stop if RANK == 0 else None]\n dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks\n self.stop = broadcast_list[0]\n if self.stop: # training time exceeded\n break\n\n # Log\n if RANK in {-1, 0}:\n loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1\n pbar.set_description(\n (\"%11s\" * 2 + \"%11.4g\" * (2 + loss_length))\n % (\n f\"{epoch + 1}/{self.epochs}\",\n f\"{self._get_memory():.3g}G\", # (GB) GPU memory util\n *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses\n batch[\"cls\"].shape[0], # batch size, i.e. 
8\n batch[\"img\"].shape[-1], # imgsz, i.e 640\n )\n )\n self.run_callbacks(\"on_batch_end\")\n if self.args.plots and ni in self.plot_idx:\n self.plot_training_samples(batch, ni)\n\n self.run_callbacks(\"on_train_batch_end\")\n\n self.lr = {f\"lr/pg{ir}\": x[\"lr\"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers\n self.run_callbacks(\"on_train_epoch_end\")\n if RANK in {-1, 0}:\n final_epoch = epoch + 1 >= self.epochs\n self.ema.update_attr(self.model, include=[\"yaml\", \"nc\", \"args\", \"names\", \"stride\", \"class_weights\"])\n\n # Validation\n if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:\n self._clear_memory(threshold=0.5) # prevent VRAM spike\n self.metrics, self.fitness = self.validate()\n self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})\n self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch\n if self.args.time:\n self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)\n\n # Save model\n if self.args.save or final_epoch:\n self.save_model()\n self.run_callbacks(\"on_model_save\")\n\n # Scheduler\n t = time.time()\n self.epoch_time = t - self.epoch_time_start\n self.epoch_time_start = t\n if self.args.time:\n mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)\n self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)\n self._setup_scheduler()\n self.scheduler.last_epoch = self.epoch # do not move\n self.stop |= epoch >= self.epochs # stop if exceeded epochs\n self.run_callbacks(\"on_fit_epoch_end\")\n self._clear_memory(0.5) # clear if memory utilization > 50%\n\n # Early Stopping\n if RANK != -1: # if DDP training\n broadcast_list = [self.stop if RANK == 0 else None]\n dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks\n self.stop = broadcast_list[0]\n if self.stop:\n break # must break all DDP ranks\n epoch += 1\n\n if RANK in {-1, 0}:\n # Do final val with best.pt\n seconds = time.time() - self.train_time_start\n LOGGER.info(f\"\\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.\")\n self.final_eval()\n if self.args.plots:\n self.plot_metrics()\n self.run_callbacks(\"on_train_end\")\n self._clear_memory()\n unset_deterministic()\n self.run_callbacks(\"teardown\")\n\n def auto_batch(self, max_num_obj=0):\n \"\"\"Calculate optimal batch size based on model and device memory constraints.\"\"\"\n return check_train_batch_size(\n model=self.model,\n imgsz=self.args.imgsz,\n amp=self.amp,\n batch=self.batch_size,\n max_num_obj=max_num_obj,\n ) # returns batch size\n\n def _get_memory(self, fraction=False):\n \"\"\"Get accelerator memory utilization in GB or as a fraction of total memory.\"\"\"\n memory, total = 0, 0\n if self.device.type == \"mps\":\n memory = torch.mps.driver_allocated_memory()\n if fraction:\n return __import__(\"psutil\").virtual_memory().percent / 100\n elif self.device.type != \"cpu\":\n memory = torch.cuda.memory_reserved()\n if fraction:\n total = torch.cuda.get_device_properties(self.device).total_memory\n return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)\n\n def _clear_memory(self, threshold: float = None):\n \"\"\"Clear accelerator memory by calling garbage collector and emptying cache.\"\"\"\n if threshold:\n assert 0 <= threshold <= 1, \"Threshold must be between 0 and 1.\"\n if self._get_memory(fraction=True) <= threshold:\n return\n gc.collect()\n if self.device.type == \"mps\":\n 
torch.mps.empty_cache()\n elif self.device.type == \"cpu\":\n return\n else:\n torch.cuda.empty_cache()\n\n def read_results_csv(self):\n \"\"\"Read results.csv into a dictionary using pandas.\"\"\"\n import pandas as pd # scope for faster 'import ultralytics'\n\n return pd.read_csv(self.csv).to_dict(orient=\"list\")\n\n def _model_train(self):\n \"\"\"Set model in training mode.\"\"\"\n self.model.train()\n # Freeze BN stat\n for n, m in self.model.named_modules():\n if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):\n m.eval()\n\n def save_model(self):\n \"\"\"Save model training checkpoints with additional metadata.\"\"\"\n import io\n\n # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)\n buffer = io.BytesIO()\n torch.save(\n {\n \"epoch\": self.epoch,\n \"best_fitness\": self.best_fitness,\n \"model\": None, # resume and final checkpoints derive from EMA\n \"ema\": deepcopy(self.ema.ema).half(),\n \"updates\": self.ema.updates,\n \"optimizer\": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),\n \"train_args\": vars(self.args), # save as dict\n \"train_metrics\": {**self.metrics, **{\"fitness\": self.fitness}},\n \"train_results\": self.read_results_csv(),\n \"date\": datetime.now().isoformat(),\n \"version\": __version__,\n \"license\": \"AGPL-3.0 (https://ultralytics.com/license)\",\n \"docs\": \"https://docs.ultralytics.com\",\n },\n buffer,\n )\n serialized_ckpt = buffer.getvalue() # get the serialized content to save\n\n # Save checkpoints\n self.last.write_bytes(serialized_ckpt) # save last.pt\n if self.best_fitness == self.fitness:\n self.best.write_bytes(serialized_ckpt) # save best.pt\n if (self.save_period > 0) and (self.epoch % self.save_period == 0):\n (self.wdir / f\"epoch{self.epoch}.pt\").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'\n # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):\n # (self.wdir / \"last_mosaic.pt\").write_bytes(serialized_ckpt) # save mosaic checkpoint\n\n def get_dataset(self):\n \"\"\"\n Get train and validation datasets from data dictionary.\n\n Returns:\n (dict): A dictionary containing the training/validation/test dataset and category names.\n \"\"\"\n try:\n if self.args.task == \"classify\":\n data = check_cls_dataset(self.args.data)\n elif self.args.data.rsplit(\".\", 1)[-1] in {\"yaml\", \"yml\"} or self.args.task in {\n \"detect\",\n \"segment\",\n \"pose\",\n \"obb\",\n }:\n data = check_det_dataset(self.args.data)\n if \"yaml_file\" in data:\n self.args.data = data[\"yaml_file\"] # for validating 'yolo train data=url.zip' usage\n except Exception as e:\n raise RuntimeError(emojis(f\"Dataset '{clean_url(self.args.data)}' error ❌ {e}\")) from e\n if self.args.single_cls:\n LOGGER.info(\"Overriding class names with single class.\")\n data[\"names\"] = {0: \"item\"}\n data[\"nc\"] = 1\n return data\n\n def setup_model(self):\n \"\"\"\n Load, create, or download model for any task.\n\n Returns:\n (dict): Optional checkpoint to resume training from.\n \"\"\"\n if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. 
No setup needed\n return\n\n cfg, weights = self.model, None\n ckpt = None\n if str(self.model).endswith(\".pt\"):\n weights, ckpt = attempt_load_one_weight(self.model)\n cfg = weights.yaml\n elif isinstance(self.args.pretrained, (str, Path)):\n weights, _ = attempt_load_one_weight(self.args.pretrained)\n self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)\n return ckpt\n\n def optimizer_step(self):\n \"\"\"Perform a single step of the training optimizer with gradient clipping and EMA update.\"\"\"\n self.scaler.unscale_(self.optimizer) # unscale gradients\n torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients\n self.scaler.step(self.optimizer)\n self.scaler.update()\n self.optimizer.zero_grad()\n if self.ema:\n self.ema.update(self.model)\n\n def preprocess_batch(self, batch):\n \"\"\"Allow custom preprocessing model inputs and ground truths depending on task type.\"\"\"\n return batch\n\n def validate(self):\n \"\"\"\n Run validation on test set using self.validator.\n\n Returns:\n metrics (dict): Dictionary of validation metrics.\n fitness (float): Fitness score for the validation.\n \"\"\"\n metrics = self.validator(self)\n fitness = metrics.pop(\"fitness\", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found\n if not self.best_fitness or self.best_fitness < fitness:\n self.best_fitness = fitness\n return metrics, fitness\n\n def get_model(self, cfg=None, weights=None, verbose=True):\n \"\"\"Get model and raise NotImplementedError for loading cfg files.\"\"\"\n raise NotImplementedError(\"This task trainer doesn't support loading cfg files\")\n\n def get_validator(self):\n \"\"\"Return a NotImplementedError when the get_validator function is called.\"\"\"\n raise NotImplementedError(\"get_validator function not implemented in trainer\")\n\n def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode=\"train\"):\n \"\"\"Return dataloader derived from torch.data.Dataloader.\"\"\"\n raise NotImplementedError(\"get_dataloader function not implemented in trainer\")\n\n def build_dataset(self, img_path, mode=\"train\", batch=None):\n \"\"\"Build dataset.\"\"\"\n raise NotImplementedError(\"build_dataset function not implemented in trainer\")\n\n def label_loss_items(self, loss_items=None, prefix=\"train\"):\n \"\"\"\n Return a loss dict with labelled training loss items tensor.\n\n Note:\n This is not needed for classification but necessary for segmentation & detection\n \"\"\"\n return {\"loss\": loss_items} if loss_items is not None else [\"loss\"]\n\n def set_model_attributes(self):\n \"\"\"Set or update model parameters before training.\"\"\"\n self.model.names = self.data[\"names\"]\n\n def build_targets(self, preds, targets):\n \"\"\"Build target tensors for training YOLO model.\"\"\"\n pass\n\n def progress_string(self):\n \"\"\"Return a string describing training progress.\"\"\"\n return \"\"\n\n # TODO: may need to put these following functions into callback\n def plot_training_samples(self, batch, ni):\n \"\"\"Plot training samples during YOLO training.\"\"\"\n pass\n\n def plot_training_labels(self):\n \"\"\"Plot training labels for YOLO model.\"\"\"\n pass\n\n def save_metrics(self, metrics):\n \"\"\"Save training metrics to a CSV file.\"\"\"\n keys, vals = list(metrics.keys()), list(metrics.values())\n n = len(metrics) + 2 # number of cols\n s = \"\" if self.csv.exists() else ((\"%s,\" * n % tuple([\"epoch\", \"time\"] + keys)).rstrip(\",\") + \"\\n\") # header\n 
t = time.time() - self.train_time_start\n with open(self.csv, \"a\", encoding=\"utf-8\") as f:\n f.write(s + (\"%.6g,\" * n % tuple([self.epoch + 1, t] + vals)).rstrip(\",\") + \"\\n\")\n\n def plot_metrics(self):\n \"\"\"Plot and display metrics visually.\"\"\"\n pass\n\n def on_plot(self, name, data=None):\n \"\"\"Register plots (e.g. to be consumed in callbacks).\"\"\"\n path = Path(name)\n self.plots[path] = {\"data\": data, \"timestamp\": time.time()}\n\n def final_eval(self):\n \"\"\"Perform final evaluation and validation for object detection YOLO model.\"\"\"\n ckpt = {}\n for f in self.last, self.best:\n if f.exists():\n if f is self.last:\n ckpt = strip_optimizer(f)\n elif f is self.best:\n k = \"train_results\" # update best.pt train_metrics from last.pt\n strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)\n LOGGER.info(f\"\\nValidating {f}...\")\n self.validator.args.plots = self.args.plots\n self.metrics = self.validator(model=f)\n self.metrics.pop(\"fitness\", None)\n self.run_callbacks(\"on_fit_epoch_end\")\n\n def check_resume(self, overrides):\n \"\"\"Check if resume checkpoint exists and update arguments accordingly.\"\"\"\n resume = self.args.resume\n if resume:\n try:\n exists = isinstance(resume, (str, Path)) and Path(resume).exists()\n last = Path(check_file(resume) if exists else get_latest_run())\n\n # Check that resume data YAML exists, otherwise strip to force re-download of dataset\n ckpt_args = attempt_load_weights(last).args\n if not isinstance(ckpt_args[\"data\"], dict) and not Path(ckpt_args[\"data\"]).exists():\n ckpt_args[\"data\"] = self.args.data\n\n resume = True\n self.args = get_cfg(ckpt_args)\n self.args.model = self.args.resume = str(last) # reinstate model\n for k in (\n \"imgsz\",\n \"batch\",\n \"device\",\n \"close_mosaic\",\n ): # allow arg updates to reduce memory or update device on resume\n if k in overrides:\n setattr(self.args, k, overrides[k])\n\n except Exception as e:\n raise FileNotFoundError(\n \"Resume checkpoint not found. Please pass a valid checkpoint to resume from, \"\n \"i.e. 'yolo train resume model=path/to/last.pt'\"\n ) from e\n self.resume = resume\n\n def resume_training(self, ckpt):\n \"\"\"Resume YOLO training from given epoch and best fitness.\"\"\"\n if ckpt is None or not self.resume:\n return\n best_fitness = 0.0\n start_epoch = ckpt.get(\"epoch\", -1) + 1\n if ckpt.get(\"optimizer\", None) is not None:\n self.optimizer.load_state_dict(ckpt[\"optimizer\"]) # optimizer\n best_fitness = ckpt[\"best_fitness\"]\n if self.ema and ckpt.get(\"ema\"):\n self.ema.ema.load_state_dict(ckpt[\"ema\"].float().state_dict()) # EMA\n self.ema.updates = ckpt[\"updates\"]\n assert start_epoch > 0, (\n f\"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\\n\"\n f\"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'\"\n )\n LOGGER.info(f\"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs\")\n if self.epochs < start_epoch:\n LOGGER.info(\n f\"{self.model} has been trained for {ckpt['epoch']} epochs. 
Fine-tuning for {self.epochs} more epochs.\"\n )\n self.epochs += ckpt[\"epoch\"] # finetune additional epochs\n self.best_fitness = best_fitness\n self.start_epoch = start_epoch\n if start_epoch > (self.epochs - self.args.close_mosaic):\n self._close_dataloader_mosaic()\n\n def _close_dataloader_mosaic(self):\n \"\"\"Update dataloaders to stop using mosaic augmentation.\"\"\"\n if hasattr(self.train_loader.dataset, \"mosaic\"):\n self.train_loader.dataset.mosaic = False\n if hasattr(self.train_loader.dataset, \"close_mosaic\"):\n LOGGER.info(\"Closing dataloader mosaic\")\n self.train_loader.dataset.close_mosaic(hyp=copy(self.args))\n\n def build_optimizer(self, model, name=\"auto\", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):\n \"\"\"\n Construct an optimizer for the given model.\n\n Args:\n model (torch.nn.Module): The model for which to build an optimizer.\n name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected\n based on the number of iterations.\n lr (float, optional): The learning rate for the optimizer.\n momentum (float, optional): The momentum factor for the optimizer.\n decay (float, optional): The weight decay for the optimizer.\n iterations (float, optional): The number of iterations, which determines the optimizer if\n name is 'auto'.\n\n Returns:\n (torch.optim.Optimizer): The constructed optimizer.\n \"\"\"\n g = [], [], [] # optimizer parameter groups\n bn = tuple(v for k, v in nn.__dict__.items() if \"Norm\" in k) # normalization layers, i.e. BatchNorm2d()\n if name == \"auto\":\n LOGGER.info(\n f\"{colorstr('optimizer:')} 'optimizer=auto' found, \"\n f\"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and \"\n f\"determining best 'optimizer', 'lr0' and 'momentum' automatically... \"\n )\n nc = self.data.get(\"nc\", 10) # number of classes\n lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places\n name, lr, momentum = (\"SGD\", 0.01, 0.9) if iterations > 10000 else (\"AdamW\", lr_fit, 0.9)\n self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam\n\n for module_name, module in model.named_modules():\n for param_name, param in module.named_parameters(recurse=False):\n fullname = f\"{module_name}.{param_name}\" if module_name else param_name\n if \"bias\" in fullname: # bias (no decay)\n g[2].append(param)\n elif isinstance(module, bn) or \"logit_scale\" in fullname: # weight (no decay)\n # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'\n g[1].append(param)\n else: # weight (with decay)\n g[0].append(param)\n\n optimizers = {\"Adam\", \"Adamax\", \"AdamW\", \"NAdam\", \"RAdam\", \"RMSProp\", \"SGD\", \"auto\"}\n name = {x.lower(): x for x in optimizers}.get(name.lower())\n if name in {\"Adam\", \"Adamax\", \"AdamW\", \"NAdam\", \"RAdam\"}:\n optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)\n elif name == \"RMSProp\":\n optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)\n elif name == \"SGD\":\n optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)\n else:\n raise NotImplementedError(\n f\"Optimizer '{name}' not found in list of available optimizers {optimizers}. 
\"\n \"Request support for addition optimizers at https://github.com/ultralytics/ultralytics.\"\n )\n\n optimizer.add_param_group({\"params\": g[0], \"weight_decay\": decay}) # add g0 with weight_decay\n optimizer.add_param_group({\"params\": g[1], \"weight_decay\": 0.0}) # add g1 (BatchNorm2d weights)\n LOGGER.info(\n f\"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups \"\n f\"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)\"\n )\n return optimizer", "chunk_type": "class", "name": "BaseTrainer", "file_path": "ultralytics\\ultralytics\\engine\\trainer.py", "start_line": 59, "end_line": 874, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "A base class for creating trainers.\n\nThis class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,\nand various training utilities. It supports both single-GPU and multi-GPU distributed training.\n\nAttributes:\n args (SimpleNamespace): Configuration for the trainer.\n validator (BaseValidator): Validator instance.\n model (nn.Module): Model instance.\n callbacks (defaultdict): Dictionary of callbacks.\n save_dir (Path): Directory to save results.\n wdir (Path): Directory to save weights.\n last (Path): Path to the last checkpoint.\n best (Path): Path to the best checkpoint.\n save_period (int): Save checkpoint every x epochs (disabled if < 1).\n batch_size (int): Batch size for training.\n epochs (int): Number of epochs to train for.\n start_epoch (int): Starting epoch for training.\n device (torch.device): Device to use for training.\n amp (bool): Flag to enable AMP (Automatic Mixed Precision).\n scaler (amp.GradScaler): Gradient scaler for AMP.\n data (str): Path to data.\n ema (nn.Module): EMA (Exponential Moving Average) of the model.\n resume (bool): Resume training from a checkpoint.\n lf (nn.Module): Loss function.\n scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.\n best_fitness (float): The best fitness value achieved.\n fitness (float): Current fitness value.\n loss (float): Current loss value.\n tloss (float): Total loss value.\n loss_names (list): List of loss names.\n csv (Path): Path to results CSV file.\n metrics (dict): Dictionary of metrics.\n plots (dict): Dictionary of plots.\n\nMethods:\n train: Execute the training process.\n validate: Run validation on the test set.\n save_model: Save model training checkpoints.\n get_dataset: Get train and validation datasets.\n setup_model: Load, create, or download model.\n build_optimizer: Construct an optimizer for the model.\n\nExamples:\n Initialize a trainer and start training\n >>> trainer = BaseTrainer(cfg=\"config.yaml\")\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "gc", "math", "os", "subprocess", "time", "warnings", "copy.copy", "copy.deepcopy", "datetime.datetime", "datetime.timedelta", "pathlib.Path", "numpy", "torch", "torch.distributed", "torch.nn", "torch.optim", "ultralytics.__version__", "ultralytics.cfg.get_cfg", "ultralytics.cfg.get_save_dir", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.tasks.attempt_load_one_weight", "ultralytics.nn.tasks.attempt_load_weights", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOCAL_RANK", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.TQDM", "ultralytics.utils.YAML", 
"ultralytics.utils.callbacks", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.autobatch.check_train_batch_size", "ultralytics.utils.checks.check_amp", "ultralytics.utils.checks.check_file", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_model_file_from_stem", "ultralytics.utils.checks.print_args", "ultralytics.utils.dist.ddp_cleanup", "ultralytics.utils.dist.generate_ddp_command", "ultralytics.utils.files.get_latest_run", "ultralytics.utils.torch_utils.TORCH_2_4", "ultralytics.utils.torch_utils.EarlyStopping", "ultralytics.utils.torch_utils.ModelEMA", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.convert_optimizer_state_dict_to_fp16", "ultralytics.utils.torch_utils.init_seeds", "ultralytics.utils.torch_utils.one_cycle", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.strip_optimizer", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "ultralytics.utils.torch_utils.unset_deterministic", "pandas", "io" ], "chunk_id": "class_BaseTrainer_0c48e0c7" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_7b5eb9d7" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_39efc685" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_faa43541" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_2ef53617" }, { "content": "from typing import Dict, List, Optional", "chunk_type": "import", "name": "Dict, List, Optional", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Dict, List, Optional_1e42f284" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_02cb1ae1" }, { "content": "from ultralytics.cfg import get_cfg, get_save_dir", 
"chunk_type": "import", "name": "get_cfg, get_save_dir", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cfg, get_save_dir_e2fdabd0" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 26, "end_line": 26, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr_5d6da710" }, { "content": "from ultralytics.utils.patches import torch_load", "chunk_type": "import", "name": "torch_load", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_load_e9c82eb1" }, { "content": "from ultralytics.utils.plotting import plot_tune_results", "chunk_type": "import", "name": "plot_tune_results", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 28, "end_line": 28, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_plot_tune_results_8d6321ea" }, { "content": "class Tuner:\n \"\"\"\n A class for hyperparameter tuning of YOLO models.\n\n The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the\n search space and retraining the model to evaluate their performance.\n\n Attributes:\n space (Dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.\n tune_dir (Path): Directory where evolution logs and results will be saved.\n tune_csv (Path): Path to the CSV file where evolution logs are saved.\n args (dict): Configuration arguments for the tuning process.\n callbacks (list): Callback functions to be executed during tuning.\n prefix (str): Prefix string for logging messages.\n\n Methods:\n _mutate: Mutate hyperparameters based on bounds and scaling factors.\n __call__: Execute the hyperparameter evolution across multiple iterations.\n\n Examples:\n Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.tune(\n ... data=\"coco8.yaml\", epochs=10, iterations=300, optimizer=\"AdamW\", plots=False, save=False, val=False\n ... 
)\n\n Tune with custom search space.\n >>> model.tune(space={key1: val1, key2: val2}) # custom search space dictionary\n \"\"\"\n\n def __init__(self, args=DEFAULT_CFG, _callbacks: Optional[List] = None):\n \"\"\"\n Initialize the Tuner with configurations.\n\n Args:\n args (dict): Configuration for hyperparameter evolution.\n _callbacks (List, optional): Callback functions to be executed during tuning.\n \"\"\"\n self.space = args.pop(\"space\", None) or { # key: (min, max, gain(optional))\n # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),\n \"lr0\": (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)\n \"lrf\": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)\n \"momentum\": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1\n \"weight_decay\": (0.0, 0.001), # optimizer weight decay 5e-4\n \"warmup_epochs\": (0.0, 5.0), # warmup epochs (fractions ok)\n \"warmup_momentum\": (0.0, 0.95), # warmup initial momentum\n \"box\": (1.0, 20.0), # box loss gain\n \"cls\": (0.2, 4.0), # cls loss gain (scale with pixels)\n \"dfl\": (0.4, 6.0), # dfl loss gain\n \"hsv_h\": (0.0, 0.1), # image HSV-Hue augmentation (fraction)\n \"hsv_s\": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)\n \"hsv_v\": (0.0, 0.9), # image HSV-Value augmentation (fraction)\n \"degrees\": (0.0, 45.0), # image rotation (+/- deg)\n \"translate\": (0.0, 0.9), # image translation (+/- fraction)\n \"scale\": (0.0, 0.95), # image scale (+/- gain)\n \"shear\": (0.0, 10.0), # image shear (+/- deg)\n \"perspective\": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001\n \"flipud\": (0.0, 1.0), # image flip up-down (probability)\n \"fliplr\": (0.0, 1.0), # image flip left-right (probability)\n \"bgr\": (0.0, 1.0), # image channel bgr (probability)\n \"mosaic\": (0.0, 1.0), # image mosaic (probability)\n \"mixup\": (0.0, 1.0), # image mixup (probability)\n \"cutmix\": (0.0, 1.0), # image cutmix (probability)\n \"copy_paste\": (0.0, 1.0), # segment copy-paste (probability)\n }\n self.args = get_cfg(overrides=args)\n self.args.exist_ok = self.args.resume # resume w/ same tune_dir\n self.tune_dir = get_save_dir(self.args, name=self.args.name or \"tune\")\n self.args.name, self.args.exist_ok, self.args.resume = (None, False, False) # reset to not affect training\n self.tune_csv = self.tune_dir / \"tune_results.csv\"\n self.callbacks = _callbacks or callbacks.get_default_callbacks()\n self.prefix = colorstr(\"Tuner: \")\n callbacks.add_integration_callbacks(self)\n LOGGER.info(\n f\"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\\n\"\n f\"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning\"\n )\n\n def _mutate(\n self, parent: str = \"single\", n: int = 5, mutation: float = 0.8, sigma: float = 0.2\n ) -> Dict[str, float]:\n \"\"\"\n Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.\n\n Args:\n parent (str): Parent selection method: 'single' or 'weighted'.\n n (int): Number of parents to consider.\n mutation (float): Probability of a parameter mutation in any given iteration.\n sigma (float): Standard deviation for Gaussian random number generator.\n\n Returns:\n (Dict[str, float]): A dictionary containing mutated hyperparameters.\n \"\"\"\n if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate\n # Select parent(s)\n x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=\",\", skiprows=1)\n fitness = x[:, 0] # first column\n n = min(n, len(x)) # 
number of previous results to consider\n x = x[np.argsort(-fitness)][:n] # top n mutations\n w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0)\n if parent == \"single\" or len(x) == 1:\n # x = x[random.randint(0, n - 1)] # random selection\n x = x[random.choices(range(n), weights=w)[0]] # weighted selection\n elif parent == \"weighted\":\n x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination\n\n # Mutate\n r = np.random # method\n r.seed(int(time.time()))\n g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1\n ng = len(self.space)\n v = np.ones(ng)\n while all(v == 1): # mutate until a change occurs (prevent duplicates)\n v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)\n hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}\n else:\n hyp = {k: getattr(self.args, k) for k in self.space.keys()}\n\n # Constrain to limits\n for k, v in self.space.items():\n hyp[k] = max(hyp[k], v[0]) # lower limit\n hyp[k] = min(hyp[k], v[1]) # upper limit\n hyp[k] = round(hyp[k], 5) # significant digits\n\n return hyp\n\n def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):\n \"\"\"\n Execute the hyperparameter evolution process when the Tuner instance is called.\n\n This method iterates through the number of iterations, performing the following steps in each iteration:\n\n 1. Load the existing hyperparameters or initialize new ones.\n 2. Mutate the hyperparameters using the `_mutate` method.\n 3. Train a YOLO model with the mutated hyperparameters.\n 4. Log the fitness score and mutated hyperparameters to a CSV file.\n\n Args:\n model (Model): A pre-initialized YOLO model to be used for training.\n iterations (int): The number of generations to run the evolution for.\n cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.\n\n Note:\n The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.\n Ensure this path is set correctly in the Tuner instance.\n \"\"\"\n t0 = time.time()\n best_save_dir, best_metrics = None, None\n (self.tune_dir / \"weights\").mkdir(parents=True, exist_ok=True)\n start = 0\n if self.tune_csv.exists():\n x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=\",\", skiprows=1)\n start = x.shape[0]\n LOGGER.info(f\"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...\")\n for i in range(start, iterations):\n # Mutate hyperparameters\n mutated_hyp = self._mutate()\n LOGGER.info(f\"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}\")\n\n metrics = {}\n train_args = {**vars(self.args), **mutated_hyp}\n save_dir = get_save_dir(get_cfg(train_args))\n weights_dir = save_dir / \"weights\"\n try:\n # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)\n launch = [__import__(\"sys\").executable, \"-m\", \"ultralytics.cfg.__init__\"] # workaround yolo not found\n cmd = [*launch, \"train\", *(f\"{k}={v}\" for k, v in train_args.items())]\n return_code = subprocess.run(cmd, check=True).returncode\n ckpt_file = weights_dir / (\"best.pt\" if (weights_dir / \"best.pt\").exists() else \"last.pt\")\n metrics = torch_load(ckpt_file)[\"train_metrics\"]\n assert return_code == 0, \"training failed\"\n\n except Exception as e:\n LOGGER.error(f\"training failure for hyperparameter tuning iteration {i + 1}\\n{e}\")\n\n # Save results and mutated_hyp to CSV\n fitness = metrics.get(\"fitness\", 
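`Tuner._mutate` perturbs each hyperparameter multiplicatively: a Gaussian factor scaled by the per-key gain (third element of the space tuple, default 1.0) is drawn only with probability `mutation`, clipped to [0.3, 3.0], applied to the selected parent row, and the result is clamped back to the `(min, max)` bounds and rounded to 5 digits. The sketch below reproduces just that arithmetic with NumPy, assuming as a simplification that the parent hyperparameters are passed in as a dict rather than read from `tune_results.csv`; the `mutate` helper and the example values are illustrative.

```python
import numpy as np

# Same (min, max, optional gain) layout as Tuner.space above; values and parent row are illustrative.
space = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98, 0.3), "weight_decay": (0.0, 0.001)}
parent = {"lr0": 0.01, "momentum": 0.937, "weight_decay": 0.0005}

def mutate(parent_hyp, space, mutation=0.8, sigma=0.2, seed=0):
    """Sketch of the multiplicative Gaussian mutation used by Tuner._mutate (parent passed in directly)."""
    rng = np.random.default_rng(seed)
    gains = np.array([v[2] if len(v) == 3 else 1.0 for v in space.values()])
    ng = len(space)
    factors = np.ones(ng)
    while all(factors == 1):  # resample until at least one factor actually changes
        factors = gains * (rng.random(ng) < mutation) * rng.standard_normal(ng) * rng.random() * sigma + 1
        factors = factors.clip(0.3, 3.0)
    hyp = {k: float(parent_hyp[k] * factors[i]) for i, k in enumerate(space)}
    for k, bounds in space.items():  # clamp to (min, max) and round, as in the source
        hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)
    return hyp

print(mutate(parent, space))  # every mutated value stays inside its (min, max) bounds
```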
0.0)\n log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]\n headers = \"\" if self.tune_csv.exists() else (\",\".join([\"fitness\"] + list(self.space.keys())) + \"\\n\")\n with open(self.tune_csv, \"a\", encoding=\"utf-8\") as f:\n f.write(headers + \",\".join(map(str, log_row)) + \"\\n\")\n\n # Get best results\n x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=\",\", skiprows=1)\n fitness = x[:, 0] # first column\n best_idx = fitness.argmax()\n best_is_current = best_idx == i\n if best_is_current:\n best_save_dir = save_dir\n best_metrics = {k: round(v, 5) for k, v in metrics.items()}\n for ckpt in weights_dir.glob(\"*.pt\"):\n shutil.copy2(ckpt, self.tune_dir / \"weights\")\n elif cleanup:\n shutil.rmtree(weights_dir, ignore_errors=True) # remove iteration weights/ dir to reduce storage space\n\n # Plot tune results\n plot_tune_results(self.tune_csv)\n\n # Save and print tune results\n header = (\n f\"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\\n\"\n f\"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\\n\"\n f\"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\\n\"\n f\"{self.prefix}Best fitness metrics are {best_metrics}\\n\"\n f\"{self.prefix}Best fitness model is {best_save_dir}\\n\"\n f\"{self.prefix}Best fitness hyperparameters are printed below.\\n\"\n )\n LOGGER.info(\"\\n\" + header)\n data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}\n YAML.save(\n self.tune_dir / \"best_hyperparameters.yaml\",\n data=data,\n header=remove_colorstr(header.replace(self.prefix, \"# \")) + \"\\n\",\n )\n YAML.print(self.tune_dir / \"best_hyperparameters.yaml\")", "chunk_type": "class", "name": "Tuner", "file_path": "ultralytics\\ultralytics\\engine\\tuner.py", "start_line": 31, "end_line": 246, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "A class for hyperparameter tuning of YOLO models.\n\nThe class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the\nsearch space and retraining the model to evaluate their performance.\n\nAttributes:\n space (Dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.\n tune_dir (Path): Directory where evolution logs and results will be saved.\n tune_csv (Path): Path to the CSV file where evolution logs are saved.\n args (dict): Configuration arguments for the tuning process.\n callbacks (list): Callback functions to be executed during tuning.\n prefix (str): Prefix string for logging messages.\n\nMethods:\n _mutate: Mutate hyperparameters based on bounds and scaling factors.\n __call__: Execute the hyperparameter evolution across multiple iterations.\n\nExamples:\n Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\")\n >>> model.tune(\n ... data=\"coco8.yaml\", epochs=10, iterations=300, optimizer=\"AdamW\", plots=False, save=False, val=False\n ... 
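`Tuner.__call__` logs one CSV row per iteration (fitness first, then the hyperparameters in `self.space` key order) and re-reads the file to locate the best iteration with `fitness.argmax()` before saving `best_hyperparameters.yaml`. The snippet below shows that read-back step on a small hypothetical `tune_results.csv` written in the same layout; the file name and the values are made up for the demo.

```python
import numpy as np
from pathlib import Path

# Hypothetical tune_results.csv in the layout written by Tuner.__call__: header, then one row per iteration.
csv_text = "fitness,lr0,momentum\n0.412,0.01,0.9\n0.437,0.012,0.88\n0.429,0.008,0.93\n"
tune_csv = Path("tune_results_demo.csv")
tune_csv.write_text(csv_text, encoding="utf-8")

x = np.loadtxt(tune_csv, ndmin=2, delimiter=",", skiprows=1)  # same call as in the source
fitness = x[:, 0]                      # first column is fitness
best_idx = int(fitness.argmax())       # best iteration (0-based)
keys = ["lr0", "momentum"]             # space keys, in the order they were logged
best = {k: float(x[best_idx, i + 1]) for i, k in enumerate(keys)}
print(best_idx + 1, best)              # -> 2 {'lr0': 0.012, 'momentum': 0.88}
```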
)\n\n Tune with custom search space.\n >>> model.tune(space={key1: val1, key2: val2}) # custom search space dictionary", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "random", "shutil", "subprocess", "time", "typing.Dict", "typing.List", "typing.Optional", "numpy", "ultralytics.cfg.get_cfg", "ultralytics.cfg.get_save_dir", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.remove_colorstr", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.plot_tune_results" ], "chunk_id": "class_Tuner_1f6ca3f7" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 26, "end_line": 26, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_9a9f848d" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_c5fcc7e7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 28, "end_line": 28, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_ee4f67a8" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 30, "end_line": 30, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_4d03ea28" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 31, "end_line": 31, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_c1e7c9c3" }, { "content": "from ultralytics.cfg import get_cfg, get_save_dir", "chunk_type": "import", "name": "get_cfg, get_save_dir", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 33, "end_line": 33, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cfg, get_save_dir_a1153613" }, { "content": "from ultralytics.data.utils import check_cls_dataset, check_det_dataset", "chunk_type": "import", "name": "check_cls_dataset, check_det_dataset", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 34, "end_line": 34, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_check_cls_dataset, check_det_dataset_323b47c0" }, { "content": "from ultralytics.nn.autobackend import AutoBackend", "chunk_type": "import", "name": "AutoBackend", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 35, "end_line": 35, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AutoBackend_7e0849c0" }, { "content": "from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis", "chunk_type": "import", "name": "LOGGER, TQDM, callbacks, colorstr, emojis", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, TQDM, callbacks, colorstr, emojis_d65694f7" }, { "content": "from ultralytics.utils.checks import check_imgsz", "chunk_type": "import", "name": "check_imgsz", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_imgsz_71010a9b" }, { "content": "from ultralytics.utils.ops import Profile", "chunk_type": "import", "name": "Profile", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 38, "end_line": 38, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Profile_7025a3f1" }, { "content": "from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode", "chunk_type": "import", "name": "de_parallel, select_device, smart_inference_mode", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_de_parallel, select_device, smart_inference_mode_1216d6a5" }, { "content": "class BaseValidator:\n \"\"\"\n A base class for creating validators.\n\n This class provides the foundation for validation processes, including model evaluation, metric computation, and\n result visualization.\n\n Attributes:\n args (SimpleNamespace): Configuration for the validator.\n dataloader (DataLoader): Dataloader to use for validation.\n model (nn.Module): Model to validate.\n data (dict): Data dictionary containing dataset information.\n device (torch.device): Device to use for validation.\n batch_i (int): Current batch index.\n training (bool): Whether the model is in training mode.\n names (dict): Class names mapping.\n seen (int): Number of images seen so far during validation.\n stats (dict): Statistics collected during validation.\n confusion_matrix: Confusion matrix for classification evaluation.\n nc (int): Number of classes.\n iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.\n jdict (list): List to store JSON validation results.\n speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective\n batch processing times in 
milliseconds.\n save_dir (Path): Directory to save results.\n plots (dict): Dictionary to store plots for visualization.\n callbacks (dict): Dictionary to store various callback functions.\n stride (int): Model stride for padding calculations.\n loss (torch.Tensor): Accumulated loss during training validation.\n\n Methods:\n __call__: Execute validation process, running inference on dataloader and computing performance metrics.\n match_predictions: Match predictions to ground truth objects using IoU.\n add_callback: Append the given callback to the specified event.\n run_callbacks: Run all callbacks associated with a specified event.\n get_dataloader: Get data loader from dataset path and batch size.\n build_dataset: Build dataset from image path.\n preprocess: Preprocess an input batch.\n postprocess: Postprocess the predictions.\n init_metrics: Initialize performance metrics for the YOLO model.\n update_metrics: Update metrics based on predictions and batch.\n finalize_metrics: Finalize and return all metrics.\n get_stats: Return statistics about the model's performance.\n print_results: Print the results of the model's predictions.\n get_desc: Get description of the YOLO model.\n on_plot: Register plots for visualization.\n plot_val_samples: Plot validation samples during training.\n plot_predictions: Plot YOLO model predictions on batch images.\n pred_to_json: Convert predictions to JSON format.\n eval_json: Evaluate and return JSON format of prediction statistics.\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):\n \"\"\"\n Initialize a BaseValidator instance.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.\n save_dir (Path, optional): Directory to save results.\n args (SimpleNamespace, optional): Configuration for the validator.\n _callbacks (dict, optional): Dictionary to store various callback functions.\n \"\"\"\n self.args = get_cfg(overrides=args)\n self.dataloader = dataloader\n self.stride = None\n self.data = None\n self.device = None\n self.batch_i = None\n self.training = True\n self.names = None\n self.seen = None\n self.stats = None\n self.confusion_matrix = None\n self.nc = None\n self.iouv = None\n self.jdict = None\n self.speed = {\"preprocess\": 0.0, \"inference\": 0.0, \"loss\": 0.0, \"postprocess\": 0.0}\n\n self.save_dir = save_dir or get_save_dir(self.args)\n (self.save_dir / \"labels\" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)\n if self.args.conf is None:\n self.args.conf = 0.01 if self.args.task == \"obb\" else 0.001 # reduce OBB val memory usage\n self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)\n\n self.plots = {}\n self.callbacks = _callbacks or callbacks.get_default_callbacks()\n\n @smart_inference_mode()\n def __call__(self, trainer=None, model=None):\n \"\"\"\n Execute validation process, running inference on dataloader and computing performance metrics.\n\n Args:\n trainer (object, optional): Trainer object that contains the model to validate.\n model (nn.Module, optional): Model to validate if not using a trainer.\n\n Returns:\n (dict): Dictionary containing validation statistics.\n \"\"\"\n self.training = trainer is not None\n augment = self.args.augment and (not self.training)\n if self.training:\n self.device = trainer.device\n self.data = trainer.data\n # Force FP16 val during training\n self.args.half = self.device.type != \"cpu\" and trainer.amp\n model = trainer.ema.ema or trainer.model\n model = 
model.half() if self.args.half else model.float()\n self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)\n self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)\n model.eval()\n else:\n if str(self.args.model).endswith(\".yaml\") and model is None:\n LOGGER.warning(\"validating an untrained model YAML will result in 0 mAP.\")\n callbacks.add_integration_callbacks(self)\n model = AutoBackend(\n weights=model or self.args.model,\n device=select_device(self.args.device, self.args.batch),\n dnn=self.args.dnn,\n data=self.args.data,\n fp16=self.args.half,\n )\n self.device = model.device # update device\n self.args.half = model.fp16 # update half\n stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine\n imgsz = check_imgsz(self.args.imgsz, stride=stride)\n if engine:\n self.args.batch = model.batch_size\n elif not (pt or jit or getattr(model, \"dynamic\", False)):\n self.args.batch = model.metadata.get(\"batch\", 1) # export.py models default to batch-size 1\n LOGGER.info(f\"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})\")\n\n if str(self.args.data).rsplit(\".\", 1)[-1] in {\"yaml\", \"yml\"}:\n self.data = check_det_dataset(self.args.data)\n elif self.args.task == \"classify\":\n self.data = check_cls_dataset(self.args.data, split=self.args.split)\n else:\n raise FileNotFoundError(emojis(f\"Dataset '{self.args.data}' for task={self.args.task} not found ❌\"))\n\n if self.device.type in {\"cpu\", \"mps\"}:\n self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading\n if not (pt or (getattr(model, \"dynamic\", False) and not model.imx)):\n self.args.rect = False\n self.stride = model.stride # used in get_dataloader() for padding\n self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)\n\n model.eval()\n model.warmup(imgsz=(1 if pt else self.args.batch, self.data[\"channels\"], imgsz, imgsz)) # warmup\n\n self.run_callbacks(\"on_val_start\")\n dt = (\n Profile(device=self.device),\n Profile(device=self.device),\n Profile(device=self.device),\n Profile(device=self.device),\n )\n bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))\n self.init_metrics(de_parallel(model))\n self.jdict = [] # empty before each val\n for batch_i, batch in enumerate(bar):\n self.run_callbacks(\"on_val_batch_start\")\n self.batch_i = batch_i\n # Preprocess\n with dt[0]:\n batch = self.preprocess(batch)\n\n # Inference\n with dt[1]:\n preds = model(batch[\"img\"], augment=augment)\n\n # Loss\n with dt[2]:\n if self.training:\n self.loss += model.loss(batch, preds)[1]\n\n # Postprocess\n with dt[3]:\n preds = self.postprocess(preds)\n\n self.update_metrics(preds, batch)\n if self.args.plots and batch_i < 3:\n self.plot_val_samples(batch, batch_i)\n self.plot_predictions(batch, preds, batch_i)\n\n self.run_callbacks(\"on_val_batch_end\")\n stats = self.get_stats()\n self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))\n self.finalize_metrics()\n self.print_results()\n self.run_callbacks(\"on_val_end\")\n if self.training:\n model.float()\n results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix=\"val\")}\n return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats\n else:\n LOGGER.info(\n \"Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per 
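In `BaseValidator.__call__`, the four `Profile` contexts accumulate preprocess, inference, loss and postprocess time, and the per-image speed is reported as `x.t / len(dataset) * 1e3` milliseconds. The sketch below shows the same accumulate-then-normalize pattern with a plain stand-in timer so it runs without Ultralytics installed; `AccumTimer` is an illustrative substitute for `ultralytics.utils.ops.Profile`, not the real class.

```python
import time

class AccumTimer:
    """Stand-in for Profile: accumulates elapsed seconds in .t across repeated uses as a context manager."""
    def __init__(self):
        self.t = 0.0
    def __enter__(self):
        self._start = time.perf_counter()
        return self
    def __exit__(self, *exc):
        self.t += time.perf_counter() - self._start

dt = (AccumTimer(), AccumTimer(), AccumTimer(), AccumTimer())  # preprocess, inference, loss, postprocess
num_images = 8
for _ in range(num_images):
    with dt[0]:
        time.sleep(0.001)  # pretend preprocessing
    with dt[1]:
        time.sleep(0.004)  # pretend inference
    with dt[2]:
        pass               # loss is only accumulated during training validation
    with dt[3]:
        time.sleep(0.001)  # pretend postprocessing

speed_keys = ("preprocess", "inference", "loss", "postprocess")
speed = dict(zip(speed_keys, (x.t / num_images * 1e3 for x in dt)))  # ms per image, as in the source
print({k: round(v, 1) for k, v in speed.items()})
```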
image\".format(\n *tuple(self.speed.values())\n )\n )\n if self.args.save_json and self.jdict:\n with open(str(self.save_dir / \"predictions.json\"), \"w\", encoding=\"utf-8\") as f:\n LOGGER.info(f\"Saving {f.name}...\")\n json.dump(self.jdict, f) # flatten and save\n stats = self.eval_json(stats) # update stats\n if self.args.plots or self.args.save_json:\n LOGGER.info(f\"Results saved to {colorstr('bold', self.save_dir)}\")\n return stats\n\n def match_predictions(\n self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False\n ) -> torch.Tensor:\n \"\"\"\n Match predictions to ground truth objects using IoU.\n\n Args:\n pred_classes (torch.Tensor): Predicted class indices of shape (N,).\n true_classes (torch.Tensor): Target class indices of shape (M,).\n iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.\n use_scipy (bool, optional): Whether to use scipy for matching (more precise).\n\n Returns:\n (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.\n \"\"\"\n # Dx10 matrix, where D - detections, 10 - IoU thresholds\n correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)\n # LxD matrix where L - labels (rows), D - detections (columns)\n correct_class = true_classes[:, None] == pred_classes\n iou = iou * correct_class # zero out the wrong classes\n iou = iou.cpu().numpy()\n for i, threshold in enumerate(self.iouv.cpu().tolist()):\n if use_scipy:\n # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708\n import scipy # scope import to avoid importing for all commands\n\n cost_matrix = iou * (iou >= threshold)\n if cost_matrix.any():\n labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)\n valid = cost_matrix[labels_idx, detections_idx] > 0\n if valid.any():\n correct[detections_idx[valid], i] = True\n else:\n matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match\n matches = np.array(matches).T\n if matches.shape[0]:\n if matches.shape[0] > 1:\n matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]\n matches = matches[np.unique(matches[:, 1], return_index=True)[1]]\n matches = matches[np.unique(matches[:, 0], return_index=True)[1]]\n correct[matches[:, 1].astype(int), i] = True\n return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)\n\n def add_callback(self, event: str, callback):\n \"\"\"Append the given callback to the specified event.\"\"\"\n self.callbacks[event].append(callback)\n\n def run_callbacks(self, event: str):\n \"\"\"Run all callbacks associated with a specified event.\"\"\"\n for callback in self.callbacks.get(event, []):\n callback(self)\n\n def get_dataloader(self, dataset_path, batch_size):\n \"\"\"Get data loader from dataset path and batch size.\"\"\"\n raise NotImplementedError(\"get_dataloader function not implemented for this validator\")\n\n def build_dataset(self, img_path):\n \"\"\"Build dataset from image path.\"\"\"\n raise NotImplementedError(\"build_dataset function not implemented in validator\")\n\n def preprocess(self, batch):\n \"\"\"Preprocess an input batch.\"\"\"\n return batch\n\n def postprocess(self, preds):\n \"\"\"Postprocess the predictions.\"\"\"\n return preds\n\n def init_metrics(self, model):\n \"\"\"Initialize performance metrics for the YOLO model.\"\"\"\n pass\n\n def update_metrics(self, preds, batch):\n \"\"\"Update metrics based on predictions and batch.\"\"\"\n pass\n\n def 
finalize_metrics(self):\n \"\"\"Finalize and return all metrics.\"\"\"\n pass\n\n def get_stats(self):\n \"\"\"Return statistics about the model's performance.\"\"\"\n return {}\n\n def print_results(self):\n \"\"\"Print the results of the model's predictions.\"\"\"\n pass\n\n def get_desc(self):\n \"\"\"Get description of the YOLO model.\"\"\"\n pass\n\n @property\n def metric_keys(self):\n \"\"\"Return the metric keys used in YOLO training/validation.\"\"\"\n return []\n\n def on_plot(self, name, data=None):\n \"\"\"Register plots for visualization.\"\"\"\n self.plots[Path(name)] = {\"data\": data, \"timestamp\": time.time()}\n\n def plot_val_samples(self, batch, ni):\n \"\"\"Plot validation samples during training.\"\"\"\n pass\n\n def plot_predictions(self, batch, preds, ni):\n \"\"\"Plot YOLO model predictions on batch images.\"\"\"\n pass\n\n def pred_to_json(self, preds, batch):\n \"\"\"Convert predictions to JSON format.\"\"\"\n pass\n\n def eval_json(self, stats):\n \"\"\"Evaluate and return JSON format of prediction statistics.\"\"\"\n pass", "chunk_type": "class", "name": "BaseValidator", "file_path": "ultralytics\\ultralytics\\engine\\validator.py", "start_line": 42, "end_line": 366, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "A base class for creating validators.\n\nThis class provides the foundation for validation processes, including model evaluation, metric computation, and\nresult visualization.\n\nAttributes:\n args (SimpleNamespace): Configuration for the validator.\n dataloader (DataLoader): Dataloader to use for validation.\n model (nn.Module): Model to validate.\n data (dict): Data dictionary containing dataset information.\n device (torch.device): Device to use for validation.\n batch_i (int): Current batch index.\n training (bool): Whether the model is in training mode.\n names (dict): Class names mapping.\n seen (int): Number of images seen so far during validation.\n stats (dict): Statistics collected during validation.\n confusion_matrix: Confusion matrix for classification evaluation.\n nc (int): Number of classes.\n iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.\n jdict (list): List to store JSON validation results.\n speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective\n batch processing times in milliseconds.\n save_dir (Path): Directory to save results.\n plots (dict): Dictionary to store plots for visualization.\n callbacks (dict): Dictionary to store various callback functions.\n stride (int): Model stride for padding calculations.\n loss (torch.Tensor): Accumulated loss during training validation.\n\nMethods:\n __call__: Execute validation process, running inference on dataloader and computing performance metrics.\n match_predictions: Match predictions to ground truth objects using IoU.\n add_callback: Append the given callback to the specified event.\n run_callbacks: Run all callbacks associated with a specified event.\n get_dataloader: Get data loader from dataset path and batch size.\n build_dataset: Build dataset from image path.\n preprocess: Preprocess an input batch.\n postprocess: Postprocess the predictions.\n init_metrics: Initialize performance metrics for the YOLO model.\n update_metrics: Update metrics based on predictions and batch.\n finalize_metrics: Finalize and return all metrics.\n get_stats: Return statistics about the model's performance.\n print_results: Print the results of the model's predictions.\n get_desc: Get description of the 
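Everything task-specific in `BaseValidator` is delegated to the hook methods above (`get_dataloader`, `preprocess`, `update_metrics`, `get_stats`, ...). The sketch below shows the shape of a minimal subclass that tracks top-1 accuracy; the class name and metric key are illustrative, and it assumes batches shaped like `{"img": Tensor, "cls": Tensor}` with raw logits as predictions, so wiring it into a real run still requires implementing `get_dataloader` for your dataset.

```python
from ultralytics.engine.validator import BaseValidator

class TinyClassificationValidator(BaseValidator):
    """Illustrative subclass: accumulates top-1 accuracy over batches of images and integer labels."""

    def init_metrics(self, model):
        self.seen, self.correct = 0, 0

    def preprocess(self, batch):
        batch["img"] = batch["img"].to(self.device, non_blocking=True).float()
        return batch

    def update_metrics(self, preds, batch):
        top1 = preds.argmax(dim=1)                                    # assumes preds are raw class logits
        self.correct += int((top1 == batch["cls"].to(top1.device)).sum())
        self.seen += len(top1)

    def get_stats(self):
        return {"metrics/accuracy_top1": self.correct / max(self.seen, 1)}

    def get_desc(self):
        return "Validating (top-1 accuracy)"
```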
YOLO model.\n on_plot: Register plots for visualization.\n plot_val_samples: Plot validation samples during training.\n plot_predictions: Plot YOLO model predictions on batch images.\n pred_to_json: Convert predictions to JSON format.\n eval_json: Evaluate and return JSON format of prediction statistics.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "time", "pathlib.Path", "numpy", "torch", "ultralytics.cfg.get_cfg", "ultralytics.cfg.get_save_dir", "ultralytics.data.utils.check_cls_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.nn.autobackend.AutoBackend", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.callbacks", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.ops.Profile", "ultralytics.utils.torch_utils.de_parallel", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode", "scipy" ], "chunk_id": "class_BaseValidator_787359ce" }, { "content": "import requests", "chunk_type": "import", "name": "requests", "file_path": "ultralytics\\ultralytics\\hub\\auth.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_requests_d146b854" }, { "content": "from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials", "chunk_type": "import", "name": "HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials", "file_path": "ultralytics\\ultralytics\\hub\\auth.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials_f412c55a" }, { "content": "from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis", "chunk_type": "import", "name": "IS_COLAB, LOGGER, SETTINGS, emojis", "file_path": "ultralytics\\ultralytics\\hub\\auth.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IS_COLAB, LOGGER, SETTINGS, emojis_8fc6ca17" }, { "content": "API_KEY_URL = f\"{HUB_WEB_ROOT}/settings?tab=api+keys\"", "chunk_type": "variable", "name": "API_KEY_URL", "file_path": "ultralytics\\ultralytics\\hub\\auth.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_API_KEY_URL_39a2ac38" }, { "content": "class Auth:\n \"\"\"\n Manages authentication processes including API key handling, cookie-based authentication, and header generation.\n\n The class supports different methods of authentication:\n 1. Directly using an API key.\n 2. Authenticating using browser cookies (specifically in Google Colab).\n 3. 
Prompting the user to enter an API key.\n\n Attributes:\n id_token (str | bool): Token used for identity verification, initialized as False.\n api_key (str | bool): API key for authentication, initialized as False.\n model_key (bool): Placeholder for model key, initialized as False.\n\n Methods:\n authenticate: Attempt to authenticate with the server using either id_token or API key.\n auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.\n get_auth_header: Get the authentication header for making API requests.\n request_api_key: Prompt the user to input their API key.\n\n Examples:\n Initialize Auth with an API key\n >>> auth = Auth(api_key=\"your_api_key_here\")\n\n Initialize Auth without API key (will prompt for input)\n >>> auth = Auth()\n \"\"\"\n\n id_token = api_key = model_key = False\n\n def __init__(self, api_key: str = \"\", verbose: bool = False):\n \"\"\"\n Initialize Auth class and authenticate user.\n\n Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful\n authentication.\n\n Args:\n api_key (str): API key or combined key_id format.\n verbose (bool): Enable verbose logging.\n \"\"\"\n # Split the input API key in case it contains a combined key_model and keep only the API key part\n api_key = api_key.split(\"_\", 1)[0]\n\n # Set API key attribute as value passed or SETTINGS API key if none passed\n self.api_key = api_key or SETTINGS.get(\"api_key\", \"\")\n\n # If an API key is provided\n if self.api_key:\n # If the provided API key matches the API key in the SETTINGS\n if self.api_key == SETTINGS.get(\"api_key\"):\n # Log that the user is already logged in\n if verbose:\n LOGGER.info(f\"{PREFIX}Authenticated ✅\")\n return\n else:\n # Attempt to authenticate with the provided API key\n success = self.authenticate()\n # If the API key is not provided and the environment is a Google Colab notebook\n elif IS_COLAB:\n # Attempt to authenticate using browser cookies\n success = self.auth_with_cookies()\n else:\n # Request an API key\n success = self.request_api_key()\n\n # Update SETTINGS with the new API key after successful authentication\n if success:\n SETTINGS.update({\"api_key\": self.api_key})\n # Log that the new login was successful\n if verbose:\n LOGGER.info(f\"{PREFIX}New authentication successful ✅\")\n elif verbose:\n LOGGER.info(f\"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'\")\n\n def request_api_key(self, max_attempts: int = 3) -> bool:\n \"\"\"\n Prompt the user to input their API key.\n\n Args:\n max_attempts (int): Maximum number of authentication attempts.\n\n Returns:\n (bool): True if authentication is successful, False otherwise.\n \"\"\"\n import getpass\n\n for attempts in range(max_attempts):\n LOGGER.info(f\"{PREFIX}Login. 
Attempt {attempts + 1} of {max_attempts}\")\n input_key = getpass.getpass(f\"Enter API key from {API_KEY_URL} \")\n self.api_key = input_key.split(\"_\", 1)[0] # remove model id if present\n if self.authenticate():\n return True\n raise ConnectionError(emojis(f\"{PREFIX}Failed to authenticate ❌\"))\n\n def authenticate(self) -> bool:\n \"\"\"\n Attempt to authenticate with the server using either id_token or API key.\n\n Returns:\n (bool): True if authentication is successful, False otherwise.\n \"\"\"\n try:\n if header := self.get_auth_header():\n r = requests.post(f\"{HUB_API_ROOT}/v1/auth\", headers=header)\n if not r.json().get(\"success\", False):\n raise ConnectionError(\"Unable to authenticate.\")\n return True\n raise ConnectionError(\"User has not authenticated locally.\")\n except ConnectionError:\n self.id_token = self.api_key = False # reset invalid\n LOGGER.warning(f\"{PREFIX}Invalid API key\")\n return False\n\n def auth_with_cookies(self) -> bool:\n \"\"\"\n Attempt to fetch authentication via cookies and set id_token.\n\n User must be logged in to HUB and running in a supported browser.\n\n Returns:\n (bool): True if authentication is successful, False otherwise.\n \"\"\"\n if not IS_COLAB:\n return False # Currently only works with Colab\n try:\n authn = request_with_credentials(f\"{HUB_API_ROOT}/v1/auth/auto\")\n if authn.get(\"success\", False):\n self.id_token = authn.get(\"data\", {}).get(\"idToken\", None)\n self.authenticate()\n return True\n raise ConnectionError(\"Unable to fetch browser authentication details.\")\n except ConnectionError:\n self.id_token = False # reset invalid\n return False\n\n def get_auth_header(self):\n \"\"\"\n Get the authentication header for making API requests.\n\n Returns:\n (dict | None): The authentication header if id_token or API key is set, None otherwise.\n \"\"\"\n if self.id_token:\n return {\"authorization\": f\"Bearer {self.id_token}\"}\n elif self.api_key:\n return {\"x-api-key\": self.api_key}", "chunk_type": "class", "name": "Auth", "file_path": "ultralytics\\ultralytics\\hub\\auth.py", "start_line": 11, "end_line": 157, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Manages authentication processes including API key handling, cookie-based authentication, and header generation.\n\nThe class supports different methods of authentication:\n1. Directly using an API key.\n2. Authenticating using browser cookies (specifically in Google Colab).\n3. 
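`Auth` strips any trailing model id from a combined key with `split("_", 1)[0]` and builds request headers with `id_token` taking precedence over the API key. The helpers below mirror just those two pieces of logic without any network calls; the function names and example key are illustrative, not part of the Ultralytics API.

```python
def normalize_key(key: str) -> str:
    """Drop a trailing model id from a combined 'APIKEY_MODELID' string, as Auth.__init__ does."""
    return key.split("_", 1)[0]

def build_auth_header(id_token=None, api_key=None):
    """Mirror of Auth.get_auth_header precedence: id_token first, then API key, otherwise None."""
    if id_token:
        return {"authorization": f"Bearer {id_token}"}
    if api_key:
        return {"x-api-key": api_key}
    return None

print(normalize_key("EXAMPLEKEY_abc123"))                       # 'EXAMPLEKEY'
print(build_auth_header(api_key="EXAMPLEKEY"))                  # {'x-api-key': 'EXAMPLEKEY'}
print(build_auth_header(id_token="tok", api_key="EXAMPLEKEY"))  # Bearer header wins
```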
Prompting the user to enter an API key.\n\nAttributes:\n id_token (str | bool): Token used for identity verification, initialized as False.\n api_key (str | bool): API key for authentication, initialized as False.\n model_key (bool): Placeholder for model key, initialized as False.\n\nMethods:\n authenticate: Attempt to authenticate with the server using either id_token or API key.\n auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.\n get_auth_header: Get the authentication header for making API requests.\n request_api_key: Prompt the user to input their API key.\n\nExamples:\n Initialize Auth with an API key\n >>> auth = Auth(api_key=\"your_api_key_here\")\n\n Initialize Auth without API key (will prompt for input)\n >>> auth = Auth()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "requests", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.request_with_credentials", "ultralytics.utils.IS_COLAB", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.emojis", "getpass" ], "chunk_id": "class_Auth_061962a9" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_0f9d7aa7" }, { "content": "import threading", "chunk_type": "import", "name": "threading", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_threading_f2ab031b" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_f6439120" }, { "content": "from http import HTTPStatus", "chunk_type": "import", "name": "HTTPStatus", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HTTPStatus_31a72e26" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_073a708b" }, { "content": "from typing import Any, Dict, Optional", "chunk_type": "import", "name": "Any, Dict, Optional", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_Any, Dict, Optional_67195a33" }, { "content": "from urllib.parse import parse_qs, urlparse", "chunk_type": "import", "name": "parse_qs, urlparse", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_parse_qs, urlparse_26351545" }, { "content": "import requests", "chunk_type": "import", "name": "requests", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_requests_0a0de897" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___b913aa68" }, { "content": "from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX", "chunk_type": "import", "name": "HELP_MSG, HUB_WEB_ROOT, PREFIX", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HELP_MSG, HUB_WEB_ROOT, PREFIX_dcd7e13a" }, { "content": "from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, TQDM, checks, emojis", "chunk_type": "import", "name": "IS_COLAB, LOGGER, SETTINGS, TQDM, checks, emojis", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IS_COLAB, LOGGER, SETTINGS, TQDM, checks, emojis_34229a3e" }, { "content": "from ultralytics.utils.errors import HUBModelError", "chunk_type": "import", "name": "HUBModelError", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HUBModelError_419ed511" }, { "content": "AGENT_NAME = f\"python-{__version__}-colab\" if IS_COLAB else f\"python-{__version__}-local\"", "chunk_type": "variable", "name": "AGENT_NAME", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 89, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_AGENT_NAME_3f69279b" }, { "content": "class HUBTrainingSession:\n \"\"\"\n HUB training session for Ultralytics HUB YOLO models.\n\n This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including\n model creation, metrics tracking, and checkpoint uploading.\n\n Attributes:\n model_id (str): Identifier for the YOLO 
model being trained.\n model_url (str): URL for the model in Ultralytics HUB.\n rate_limits (Dict[str, int]): Rate limits for different API calls in seconds.\n timers (Dict[str, Any]): Timers for rate limiting.\n metrics_queue (Dict[str, Any]): Queue for the model's metrics.\n metrics_upload_failed_queue (Dict[str, Any]): Queue for metrics that failed to upload.\n model (Any): Model data fetched from Ultralytics HUB.\n model_file (str): Path to the model file.\n train_args (Dict[str, Any]): Arguments for training the model.\n client (Any): Client for interacting with Ultralytics HUB.\n filename (str): Filename of the model.\n\n Examples:\n Create a training session with a model URL\n >>> session = HUBTrainingSession(\"https://hub.ultralytics.com/models/example-model\")\n >>> session.upload_metrics()\n \"\"\"\n\n def __init__(self, identifier: str):\n \"\"\"\n Initialize the HUBTrainingSession with the provided model identifier.\n\n Args:\n identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string\n or a model key with specific format.\n\n Raises:\n ValueError: If the provided model identifier is invalid.\n ConnectionError: If connecting with global API key is not supported.\n ModuleNotFoundError: If hub-sdk package is not installed.\n \"\"\"\n from hub_sdk import HUBClient\n\n self.rate_limits = {\"metrics\": 3, \"ckpt\": 900, \"heartbeat\": 300} # rate limits (seconds)\n self.metrics_queue = {} # holds metrics for each epoch until upload\n self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed\n self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py\n self.model = None\n self.model_url = None\n self.model_file = None\n self.train_args = None\n\n # Parse input\n api_key, model_id, self.filename = self._parse_identifier(identifier)\n\n # Get credentials\n active_key = api_key or SETTINGS.get(\"api_key\")\n credentials = {\"api_key\": active_key} if active_key else None # set credentials\n\n # Initialize client\n self.client = HUBClient(credentials)\n\n # Load models\n try:\n if model_id:\n self.load_model(model_id) # load existing model\n else:\n self.model = self.client.model() # load empty model\n except Exception:\n if identifier.startswith(f\"{HUB_WEB_ROOT}/models/\") and not self.client.authenticated:\n LOGGER.warning(\n f\"{PREFIX}Please log in using 'yolo login API_KEY'. 
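`HUBTrainingSession.__init__` sets up `rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}` (seconds) and an empty `timers` dict that the HUB callbacks in `ultralytics/utils/callbacks/hub.py` use to throttle uploads; that callback code is not part of this section. The sketch below only illustrates how such a timers/rate-limits pair can gate sends and is not the actual callback implementation.

```python
import time

rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # seconds, as defined in __init__ above
timers = {}  # last-send timestamp per channel

def ready_to_send(channel: str) -> bool:
    """Illustrative throttle only: allow a send if rate_limits[channel] seconds passed since the last one."""
    now = time.time()
    if now - timers.get(channel, 0.0) > rate_limits[channel]:
        timers[channel] = now
        return True
    return False

print(ready_to_send("metrics"))  # True on the first call
print(ready_to_send("metrics"))  # False while still inside the 3-second window
```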
\"\n \"You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys.\"\n )\n\n @classmethod\n def create_session(cls, identifier: str, args: Optional[Dict[str, Any]] = None):\n \"\"\"\n Create an authenticated HUBTrainingSession or return None.\n\n Args:\n identifier (str): Model identifier used to initialize the HUB training session.\n args (Dict[str, Any], optional): Arguments for creating a new model if identifier is not a HUB model URL.\n\n Returns:\n session (HUBTrainingSession | None): An authenticated session or None if creation fails.\n \"\"\"\n try:\n session = cls(identifier)\n if args and not identifier.startswith(f\"{HUB_WEB_ROOT}/models/\"): # not a HUB model URL\n session.create_model(args)\n assert session.model.id, \"HUB model not loaded correctly\"\n return session\n # PermissionError and ModuleNotFoundError indicate hub-sdk not installed\n except (PermissionError, ModuleNotFoundError, AssertionError):\n return None\n\n def load_model(self, model_id: str):\n \"\"\"\n Load an existing model from Ultralytics HUB using the provided model identifier.\n\n Args:\n model_id (str): The identifier of the model to load.\n\n Raises:\n ValueError: If the specified HUB model does not exist.\n \"\"\"\n self.model = self.client.model(model_id)\n if not self.model.data: # then model does not exist\n raise ValueError(emojis(\"❌ The specified HUB model does not exist\")) # TODO: improve error handling\n\n self.model_url = f\"{HUB_WEB_ROOT}/models/{self.model.id}\"\n if self.model.is_trained():\n LOGGER.info(f\"Loading trained HUB model {self.model_url} 🚀\")\n url = self.model.get_weights_url(\"best\") # download URL with auth\n self.model_file = checks.check_file(url, download_dir=Path(SETTINGS[\"weights_dir\"]) / \"hub\" / self.model.id)\n return\n\n # Set training args and start heartbeats for HUB to monitor agent\n self._set_train_args()\n self.model.start_heartbeat(self.rate_limits[\"heartbeat\"])\n LOGGER.info(f\"{PREFIX}View model at {self.model_url} 🚀\")\n\n def create_model(self, model_args: Dict[str, Any]):\n \"\"\"\n Initialize a HUB training session with the specified model arguments.\n\n Args:\n model_args (Dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,\n etc.\n\n Returns:\n (None): If the model could not be created.\n \"\"\"\n payload = {\n \"config\": {\n \"batchSize\": model_args.get(\"batch\", -1),\n \"epochs\": model_args.get(\"epochs\", 300),\n \"imageSize\": model_args.get(\"imgsz\", 640),\n \"patience\": model_args.get(\"patience\", 100),\n \"device\": str(model_args.get(\"device\", \"\")), # convert None to string\n \"cache\": str(model_args.get(\"cache\", \"ram\")), # convert True, False, None to string\n },\n \"dataset\": {\"name\": model_args.get(\"data\")},\n \"lineage\": {\n \"architecture\": {\"name\": self.filename.replace(\".pt\", \"\").replace(\".yaml\", \"\")},\n \"parent\": {},\n },\n \"meta\": {\"name\": self.filename},\n }\n\n if self.filename.endswith(\".pt\"):\n payload[\"lineage\"][\"parent\"][\"name\"] = self.filename\n\n self.model.create_model(payload)\n\n # Model could not be created\n # TODO: improve error handling\n if not self.model.id:\n return None\n\n self.model_url = f\"{HUB_WEB_ROOT}/models/{self.model.id}\"\n\n # Start heartbeats for HUB to monitor agent\n self.model.start_heartbeat(self.rate_limits[\"heartbeat\"])\n\n LOGGER.info(f\"{PREFIX}View model at {self.model_url} 🚀\")\n\n @staticmethod\n def _parse_identifier(identifier: str):\n \"\"\"\n Parse the given identifier to 
determine the type and extract relevant components.\n\n The method supports different identifier formats:\n - A HUB model URL https://hub.ultralytics.com/models/MODEL\n - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY\n - A local filename that ends with '.pt' or '.yaml'\n\n Args:\n identifier (str): The identifier string to be parsed.\n\n Returns:\n api_key (str | None): Extracted API key if present.\n model_id (str | None): Extracted model ID if present.\n filename (str | None): Extracted filename if present.\n\n Raises:\n HUBModelError: If the identifier format is not recognized.\n \"\"\"\n api_key, model_id, filename = None, None, None\n if str(identifier).endswith((\".pt\", \".yaml\")):\n filename = identifier\n elif identifier.startswith(f\"{HUB_WEB_ROOT}/models/\"):\n parsed_url = urlparse(identifier)\n model_id = Path(parsed_url.path).stem # handle possible final backslash robustly\n query_params = parse_qs(parsed_url.query) # dictionary, i.e. {\"api_key\": [\"API_KEY_HERE\"]}\n api_key = query_params.get(\"api_key\", [None])[0]\n else:\n raise HUBModelError(f\"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID\")\n return api_key, model_id, filename\n\n def _set_train_args(self):\n \"\"\"\n Initialize training arguments and create a model entry on the Ultralytics HUB.\n\n This method sets up training arguments based on the model's state and updates them with any additional\n arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,\n or requires specific file setup.\n\n Raises:\n ValueError: If the model is already trained, if required dataset information is missing, or if there are\n issues with the provided training arguments.\n \"\"\"\n if self.model.is_resumable():\n # Model has saved weights\n self.train_args = {\"data\": self.model.get_dataset_url(), \"resume\": True}\n self.model_file = self.model.get_weights_url(\"last\")\n else:\n # Model has no saved weights\n self.train_args = self.model.data.get(\"train_args\") # new response\n\n # Set the model file as either a *.pt or *.yaml file\n self.model_file = (\n self.model.get_weights_url(\"parent\") if self.model.is_pretrained() else self.model.get_architecture()\n )\n\n if \"data\" not in self.train_args:\n # RF bug - datasets are sometimes not exported\n raise ValueError(\"Dataset may still be processing. 
Please wait a minute and try again.\")\n\n self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u\n self.model_id = self.model.id\n\n def request_queue(\n self,\n request_func,\n retry: int = 3,\n timeout: int = 30,\n thread: bool = True,\n verbose: bool = True,\n progress_total: Optional[int] = None,\n stream_response: Optional[bool] = None,\n *args,\n **kwargs,\n ):\n \"\"\"\n Execute request_func with retries, timeout handling, optional threading, and progress tracking.\n\n Args:\n request_func (callable): The function to execute.\n retry (int): Number of retry attempts.\n timeout (int): Maximum time to wait for the request to complete.\n thread (bool): Whether to run the request in a separate thread.\n verbose (bool): Whether to log detailed messages.\n progress_total (int, optional): Total size for progress tracking.\n stream_response (bool, optional): Whether to stream the response.\n *args (Any): Additional positional arguments for request_func.\n **kwargs (Any): Additional keyword arguments for request_func.\n\n Returns:\n (requests.Response | None): The response object if thread=False, otherwise None.\n \"\"\"\n\n def retry_request():\n \"\"\"Attempt to call request_func with retries, timeout, and optional threading.\"\"\"\n t0 = time.time() # Record the start time for the timeout\n response = None\n for i in range(retry + 1):\n if (time.time() - t0) > timeout:\n LOGGER.warning(f\"{PREFIX}Timeout for request reached. {HELP_MSG}\")\n break # Timeout reached, exit loop\n\n response = request_func(*args, **kwargs)\n if response is None:\n LOGGER.warning(f\"{PREFIX}Received no response from the request. {HELP_MSG}\")\n time.sleep(2**i) # Exponential backoff before retrying\n continue # Skip further processing and retry\n\n if progress_total:\n self._show_upload_progress(progress_total, response)\n elif stream_response:\n self._iterate_content(response)\n\n if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:\n # if request related to metrics upload\n if kwargs.get(\"metrics\"):\n self.metrics_upload_failed_queue = {}\n return response # Success, no need to retry\n\n if i == 0:\n # Initial attempt, check status code and provide messages\n message = self._get_failure_message(response, retry, timeout)\n\n if verbose:\n LOGGER.warning(f\"{PREFIX}{message} {HELP_MSG} ({response.status_code})\")\n\n if not self._should_retry(response.status_code):\n LOGGER.warning(f\"{PREFIX}Request failed. 
{HELP_MSG} ({response.status_code}\")\n break # Not an error that should be retried, exit loop\n\n time.sleep(2**i) # Exponential backoff for retries\n\n # if request related to metrics upload and exceed retries\n if response is None and kwargs.get(\"metrics\"):\n self.metrics_upload_failed_queue.update(kwargs.get(\"metrics\"))\n\n return response\n\n if thread:\n # Start a new thread to run the retry_request function\n threading.Thread(target=retry_request, daemon=True).start()\n else:\n # If running in the main thread, call retry_request directly\n return retry_request()\n\n @staticmethod\n def _should_retry(status_code: int) -> bool:\n \"\"\"Determine if a request should be retried based on the HTTP status code.\"\"\"\n retry_codes = {\n HTTPStatus.REQUEST_TIMEOUT,\n HTTPStatus.BAD_GATEWAY,\n HTTPStatus.GATEWAY_TIMEOUT,\n }\n return status_code in retry_codes\n\n def _get_failure_message(self, response: requests.Response, retry: int, timeout: int) -> str:\n \"\"\"\n Generate a retry message based on the response status code.\n\n Args:\n response (requests.Response): The HTTP response object.\n retry (int): The number of retry attempts allowed.\n timeout (int): The maximum timeout duration.\n\n Returns:\n (str): The retry message.\n \"\"\"\n if self._should_retry(response.status_code):\n return f\"Retrying {retry}x for {timeout}s.\" if retry else \"\"\n elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit\n headers = response.headers\n return (\n f\"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). \"\n f\"Please retry after {headers['Retry-After']}s.\"\n )\n else:\n try:\n return response.json().get(\"message\", \"No JSON message.\")\n except AttributeError:\n return \"Unable to read JSON.\"\n\n def upload_metrics(self):\n \"\"\"Upload model metrics to Ultralytics HUB.\"\"\"\n return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)\n\n def upload_model(\n self,\n epoch: int,\n weights: str,\n is_best: bool = False,\n map: float = 0.0,\n final: bool = False,\n ) -> None:\n \"\"\"\n Upload a model checkpoint to Ultralytics HUB.\n\n Args:\n epoch (int): The current training epoch.\n weights (str): Path to the model weights file.\n is_best (bool): Indicates if the current model is the best one so far.\n map (float): Mean average precision of the model.\n final (bool): Indicates if the model is the final model after training.\n \"\"\"\n weights = Path(weights)\n if not weights.is_file():\n last = weights.with_name(f\"last{weights.suffix}\")\n if final and last.is_file():\n LOGGER.warning(\n f\"{PREFIX} Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. \"\n \"This often happens when resuming training in transient environments like Google Colab. \"\n \"For more reliable training, consider using Ultralytics HUB Cloud. \"\n \"Learn more at https://docs.ultralytics.com/hub/cloud-training.\"\n )\n shutil.copy(last, weights) # copy last.pt to best.pt\n else:\n LOGGER.warning(f\"{PREFIX} Model upload issue. 
Missing model {weights}.\")\n return\n\n self.request_queue(\n self.model.upload_model,\n epoch=epoch,\n weights=str(weights),\n is_best=is_best,\n map=map,\n final=final,\n retry=10,\n timeout=3600,\n thread=not final,\n progress_total=weights.stat().st_size if final else None, # only show progress if final\n stream_response=True,\n )\n\n @staticmethod\n def _show_upload_progress(content_length: int, response: requests.Response) -> None:\n \"\"\"Display a progress bar to track the upload progress of a file download.\"\"\"\n with TQDM(total=content_length, unit=\"B\", unit_scale=True, unit_divisor=1024) as pbar:\n for data in response.iter_content(chunk_size=1024):\n pbar.update(len(data))\n\n @staticmethod\n def _iterate_content(response: requests.Response) -> None:\n \"\"\"Process the streamed HTTP response data.\"\"\"\n for _ in response.iter_content(chunk_size=1024):\n pass # Do nothing with data chunks", "chunk_type": "class", "name": "HUBTrainingSession", "file_path": "ultralytics\\ultralytics\\hub\\session.py", "start_line": 21, "end_line": 432, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "HUB training session for Ultralytics HUB YOLO models.\n\nThis class encapsulates the functionality for interacting with Ultralytics HUB during model training, including\nmodel creation, metrics tracking, and checkpoint uploading.\n\nAttributes:\n model_id (str): Identifier for the YOLO model being trained.\n model_url (str): URL for the model in Ultralytics HUB.\n rate_limits (Dict[str, int]): Rate limits for different API calls in seconds.\n timers (Dict[str, Any]): Timers for rate limiting.\n metrics_queue (Dict[str, Any]): Queue for the model's metrics.\n metrics_upload_failed_queue (Dict[str, Any]): Queue for metrics that failed to upload.\n model (Any): Model data fetched from Ultralytics HUB.\n model_file (str): Path to the model file.\n train_args (Dict[str, Any]): Arguments for training the model.\n client (Any): Client for interacting with Ultralytics HUB.\n filename (str): Filename of the model.\n\nExamples:\n Create a training session with a model URL\n >>> session = HUBTrainingSession(\"https://hub.ultralytics.com/models/example-model\")\n >>> session.upload_metrics()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "shutil", "threading", "time", "http.HTTPStatus", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Optional", "urllib.parse.parse_qs", "urllib.parse.urlparse", "requests", "ultralytics.__version__", "ultralytics.hub.utils.HELP_MSG", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.utils.IS_COLAB", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.emojis", "ultralytics.utils.errors.HUBModelError", "hub_sdk.HUBClient" ], "chunk_id": "class_HUBTrainingSession_bdecfc1f" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_5b122f17" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, 
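Note: the following is an illustrative, standalone sketch (not part of the extracted source) of the retry-with-exponential-backoff pattern that HUBTrainingSession.request_queue applies; the `fetch` callable and the RETRYABLE set are hypothetical stand-ins for the real request function and HTTP status checks.

import time

RETRYABLE = {408, 502, 504}  # request timeout, bad gateway, gateway timeout (as in _should_retry)

def retry_with_backoff(fetch, retry: int = 3, timeout: int = 30):
    """Call the hypothetical `fetch` until success, a non-retryable status, or the overall timeout elapses."""
    t0 = time.time()
    response = None
    for i in range(retry + 1):
        if time.time() - t0 > timeout:
            break  # overall deadline reached, give up
        response = fetch()
        if response is not None and 200 <= response.status_code < 300:
            return response  # success, stop retrying
        if response is not None and response.status_code not in RETRYABLE:
            break  # client errors such as 401/404 are not retried
        time.sleep(2**i)  # exponential backoff: 1s, 2s, 4s, ...
    return response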
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_b441572e" }, { "content": "import threading", "chunk_type": "import", "name": "threading", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_threading_0737ebda" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_f59ceac1" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_f3506214" }, { "content": "from typing import Any, Optional", "chunk_type": "import", "name": "Any, Optional", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Optional_9831e6d4" }, { "content": "import requests", "chunk_type": "import", "name": "requests", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_requests_348d32a8" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___3f66c63a" }, { "content": "from ultralytics.utils import (\n ARGV,\n ENVIRONMENT,\n IS_COLAB,\n IS_GIT_DIR,\n IS_PIP_PACKAGE,\n LOGGER,\n ONLINE,\n PYTHON_VERSION,\n RANK,\n SETTINGS,\n TESTS_RUNNING,\n TQDM,\n TryExcept,\n colorstr,\n get_git_origin_url,\n)", "chunk_type": "import", "name": "ARGV, ENVIRONMENT, IS_COLAB, IS_GIT_DIR, IS_PIP_PACKAGE, LOGGER, ONLINE, PYTHON_VERSION, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, colorstr, get_git_origin_url", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 13, "end_line": 29, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARGV, ENVIRONMENT, IS_COLAB, IS_GIT_DIR, IS_PIP_PACKAGE, LOGGER, ONLINE, PYTHON_VERSION, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, colorstr, get_git_origin_url_fdb0fa09" }, { "content": "from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES", "chunk_type": "import", "name": "GITHUB_ASSETS_NAMES", "file_path": 
"ultralytics\\ultralytics\\hub\\utils.py", "start_line": 30, "end_line": 30, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_GITHUB_ASSETS_NAMES_48452907" }, { "content": "from ultralytics.utils.torch_utils import get_cpu_info", "chunk_type": "import", "name": "get_cpu_info", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 31, "end_line": 31, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cpu_info_753d4be2" }, { "content": "HUB_API_ROOT = os.environ.get(\"ULTRALYTICS_HUB_API\", \"https://api.ultralytics.com\")", "chunk_type": "variable", "name": "HUB_API_ROOT", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 33, "end_line": 33, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_HUB_API_ROOT_995fa169" }, { "content": "HUB_WEB_ROOT = os.environ.get(\"ULTRALYTICS_HUB_WEB\", \"https://hub.ultralytics.com\")", "chunk_type": "variable", "name": "HUB_WEB_ROOT", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 34, "end_line": 34, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_HUB_WEB_ROOT_ba8102b0" }, { "content": "PREFIX = colorstr(\"Ultralytics HUB: \")", "chunk_type": "variable", "name": "PREFIX", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PREFIX_0dbfaa61" }, { "content": "HELP_MSG = \"If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.\"", "chunk_type": "variable", "name": "HELP_MSG", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 106, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_HELP_MSG_ca90ec1d" }, { "content": "def request_with_credentials(url: str) -> Any:\n \"\"\"\n Make an AJAX request with cookies attached in a Google Colab environment.\n\n Args:\n url (str): The URL to make the request to.\n\n Returns:\n (Any): The response data from the AJAX request.\n\n Raises:\n OSError: If the function is not run in a Google Colab environment.\n \"\"\"\n if not IS_COLAB:\n raise OSError(\"request_with_credentials() must run in a Colab environment\")\n from google.colab import output # noqa\n from IPython import display # noqa\n\n display.display(\n display.Javascript(\n f\"\"\"\n window._hub_tmp = new Promise((resolve, reject) => {{\n const timeout = setTimeout(() => reject(\"Failed authenticating existing browser session\"), 5000)\n fetch(\"{url}\", {{\n method: 'POST',\n credentials: 'include'\n }})\n .then((response) => resolve(response.json()))\n .then((json) => {{\n clearTimeout(timeout);\n }}).catch((err) => {{\n clearTimeout(timeout);\n 
reject(err);\n }});\n }});\n \"\"\"\n )\n )\n return output.eval_js(\"_hub_tmp\")", "chunk_type": "function", "name": "request_with_credentials", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 40, "end_line": 78, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Make an AJAX request with cookies attached in a Google Colab environment.\n\nArgs:\n url (str): The URL to make the request to.\n\nReturns:\n (Any): The response data from the AJAX request.\n\nRaises:\n OSError: If the function is not run in a Google Colab environment.", "parameters": [ "url: str" ], "return_type": "Any", "decorators": [], "complexity_score": 2, "dependencies": [ "os", "random", "threading", "time", "pathlib.Path", "typing.Any", "typing.Optional", "requests", "ultralytics.__version__", "ultralytics.utils.ARGV", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.TQDM", "ultralytics.utils.TryExcept", "ultralytics.utils.colorstr", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.downloads.GITHUB_ASSETS_NAMES", "ultralytics.utils.torch_utils.get_cpu_info", "google.colab.output", "IPython.display" ], "chunk_id": "function_request_with_credentials_28f62a25" }, { "content": "def requests_with_progress(method: str, url: str, **kwargs) -> requests.Response:\n \"\"\"\n Make an HTTP request using the specified method and URL, with an optional progress bar.\n\n Args:\n method (str): The HTTP method to use (e.g. 'GET', 'POST').\n url (str): The URL to send the request to.\n **kwargs (Any): Additional keyword arguments to pass to the underlying `requests.request` function.\n\n Returns:\n (requests.Response): The response object from the HTTP request.\n\n Notes:\n - If 'progress' is set to True, the progress bar will display the download progress for responses with a known\n content length.\n - If 'progress' is a number then progress bar will display assuming content length = progress.\n \"\"\"\n progress = kwargs.pop(\"progress\", False)\n if not progress:\n return requests.request(method, url, **kwargs)\n response = requests.request(method, url, stream=True, **kwargs)\n total = int(response.headers.get(\"content-length\", 0) if isinstance(progress, bool) else progress) # total size\n try:\n pbar = TQDM(total=total, unit=\"B\", unit_scale=True, unit_divisor=1024)\n for data in response.iter_content(chunk_size=1024):\n pbar.update(len(data))\n pbar.close()\n except requests.exceptions.ChunkedEncodingError: # avoid 'Connection broken: IncompleteRead' warnings\n response.close()\n return response", "chunk_type": "function", "name": "requests_with_progress", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 81, "end_line": 110, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Make an HTTP request using the specified method and URL, with an optional progress bar.\n\nArgs:\n method (str): The HTTP method to use (e.g. 
'GET', 'POST').\n url (str): The URL to send the request to.\n **kwargs (Any): Additional keyword arguments to pass to the underlying `requests.request` function.\n\nReturns:\n (requests.Response): The response object from the HTTP request.\n\nNotes:\n - If 'progress' is set to True, the progress bar will display the download progress for responses with a known\n content length.\n - If 'progress' is a number then progress bar will display assuming content length = progress.", "parameters": [ "method: str", "url: str" ], "return_type": "requests.Response", "decorators": [], "complexity_score": 4, "dependencies": [ "os", "random", "threading", "time", "pathlib.Path", "typing.Any", "typing.Optional", "requests", "ultralytics.__version__", "ultralytics.utils.ARGV", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.TQDM", "ultralytics.utils.TryExcept", "ultralytics.utils.colorstr", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.downloads.GITHUB_ASSETS_NAMES", "ultralytics.utils.torch_utils.get_cpu_info", "google.colab.output", "IPython.display" ], "chunk_id": "function_requests_with_progress_78c116a1" }, { "content": "def smart_request(\n method: str,\n url: str,\n retry: int = 3,\n timeout: int = 30,\n thread: bool = True,\n code: int = -1,\n verbose: bool = True,\n progress: bool = False,\n **kwargs,\n) -> Optional[requests.Response]:\n \"\"\"\n Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.\n\n Args:\n method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.\n url (str): The URL to make the request to.\n retry (int, optional): Number of retries to attempt before giving up.\n timeout (int, optional): Timeout in seconds after which the function will give up retrying.\n thread (bool, optional): Whether to execute the request in a separate daemon thread.\n code (int, optional): An identifier for the request, used for logging purposes.\n verbose (bool, optional): A flag to determine whether to print out to console or not.\n progress (bool, optional): Whether to show a progress bar during the request.\n **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.\n\n Returns:\n (requests.Response | None): The HTTP response object. If the request is executed in a separate thread, returns\n None.\n \"\"\"\n retry_codes = (408, 500) # retry only these codes\n\n @TryExcept(verbose=verbose)\n def func(func_method, func_url, **func_kwargs):\n \"\"\"Make HTTP requests with retries and timeouts, with optional progress tracking.\"\"\"\n r = None # response\n t0 = time.time() # initial time for timer\n for i in range(retry + 1):\n if (time.time() - t0) > timeout:\n break\n r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. 
get(url, data, json, files)\n if r.status_code < 300: # return codes in the 2xx range are generally considered \"good\" or \"successful\"\n break\n try:\n m = r.json().get(\"message\", \"No JSON message.\")\n except AttributeError:\n m = \"Unable to read JSON.\"\n if i == 0:\n if r.status_code in retry_codes:\n m += f\" Retrying {retry}x for {timeout}s.\" if retry else \"\"\n elif r.status_code == 429: # rate limit\n h = r.headers # response headers\n m = (\n f\"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). \"\n f\"Please retry after {h['Retry-After']}s.\"\n )\n if verbose:\n LOGGER.warning(f\"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})\")\n if r.status_code not in retry_codes:\n return r\n time.sleep(2**i) # exponential standoff\n return r\n\n args = method, url\n kwargs[\"progress\"] = progress\n if thread:\n threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()\n else:\n return func(*args, **kwargs)", "chunk_type": "function", "name": "smart_request", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 113, "end_line": 180, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.\n\nArgs:\n method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.\n url (str): The URL to make the request to.\n retry (int, optional): Number of retries to attempt before giving up.\n timeout (int, optional): Timeout in seconds after which the function will give up retrying.\n thread (bool, optional): Whether to execute the request in a separate daemon thread.\n code (int, optional): An identifier for the request, used for logging purposes.\n verbose (bool, optional): A flag to determine whether to print out to console or not.\n progress (bool, optional): Whether to show a progress bar during the request.\n **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.\n\nReturns:\n (requests.Response | None): The HTTP response object. If the request is executed in a separate thread, returns\n None.", "parameters": [ "method: str", "url: str", "retry: int", "timeout: int", "thread: bool", "code: int", "verbose: bool", "progress: bool" ], "return_type": "Optional[requests.Response]", "decorators": [], "complexity_score": 11, "dependencies": [ "os", "random", "threading", "time", "pathlib.Path", "typing.Any", "typing.Optional", "requests", "ultralytics.__version__", "ultralytics.utils.ARGV", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.TQDM", "ultralytics.utils.TryExcept", "ultralytics.utils.colorstr", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.downloads.GITHUB_ASSETS_NAMES", "ultralytics.utils.torch_utils.get_cpu_info", "google.colab.output", "IPython.display" ], "chunk_id": "function_smart_request_9a153164" }, { "content": "class Events:\n \"\"\"\n A class for collecting anonymous event analytics.\n\n Event analytics are enabled when sync=True in settings and disabled when sync=False. 
Run 'yolo settings' to see and\n update settings.\n\n Attributes:\n url (str): The URL to send anonymous events.\n events (list): List of collected events to be sent.\n rate_limit (float): The rate limit in seconds for sending events.\n t (float): Rate limit timer in seconds.\n metadata (dict): A dictionary containing metadata about the environment.\n enabled (bool): A flag to enable or disable Events based on certain conditions.\n \"\"\"\n\n url = \"https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw\"\n\n def __init__(self):\n \"\"\"Initialize the Events object with default values for events, rate_limit, and metadata.\"\"\"\n self.events = [] # events list\n self.rate_limit = 30.0 # rate limit (seconds)\n self.t = 0.0 # rate limit timer (seconds)\n self.metadata = {\n \"cli\": Path(ARGV[0]).name == \"yolo\",\n \"install\": \"git\" if IS_GIT_DIR else \"pip\" if IS_PIP_PACKAGE else \"other\",\n \"python\": PYTHON_VERSION.rsplit(\".\", 1)[0], # i.e. 3.13\n \"CPU\": get_cpu_info(),\n # \"GPU\": get_gpu_info(index=0) if cuda else None,\n \"version\": __version__,\n \"env\": ENVIRONMENT,\n \"session_id\": round(random.random() * 1e15),\n \"engagement_time_msec\": 1000,\n }\n self.enabled = (\n SETTINGS[\"sync\"]\n and RANK in {-1, 0}\n and not TESTS_RUNNING\n and ONLINE\n and (IS_PIP_PACKAGE or get_git_origin_url() == \"https://github.com/ultralytics/ultralytics.git\")\n )\n\n def __call__(self, cfg, device=None):\n \"\"\"\n Attempt to add a new event to the events list and send events if the rate limit is reached.\n\n Args:\n cfg (IterableSimpleNamespace): The configuration object containing mode and task information.\n device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').\n \"\"\"\n if not self.enabled:\n # Events disabled, do nothing\n return\n\n # Attempt to add to events\n if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)\n params = {\n **self.metadata,\n \"task\": cfg.task,\n \"model\": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else \"custom\",\n \"device\": str(device),\n }\n if cfg.mode == \"export\":\n params[\"format\"] = cfg.format\n self.events.append({\"name\": cfg.mode, \"params\": params})\n\n # Check rate limit\n t = time.time()\n if (t - self.t) < self.rate_limit:\n # Time is under rate limiter, wait to send\n return\n\n # Time is over rate limiter, send now\n data = {\"client_id\": SETTINGS[\"uuid\"], \"events\": self.events} # SHA-256 anonymized UUID hash and events list\n\n # POST equivalent to requests.post(self.url, json=data)\n smart_request(\"post\", self.url, json=data, retry=0, verbose=False)\n\n # Reset events and rate limit timer\n self.events = []\n self.t = t", "chunk_type": "class", "name": "Events", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 183, "end_line": 263, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "A class for collecting anonymous event analytics.\n\nEvent analytics are enabled when sync=True in settings and disabled when sync=False. 
Run 'yolo settings' to see and\nupdate settings.\n\nAttributes:\n url (str): The URL to send anonymous events.\n events (list): List of collected events to be sent.\n rate_limit (float): The rate limit in seconds for sending events.\n t (float): Rate limit timer in seconds.\n metadata (dict): A dictionary containing metadata about the environment.\n enabled (bool): A flag to enable or disable Events based on certain conditions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "random", "threading", "time", "pathlib.Path", "typing.Any", "typing.Optional", "requests", "ultralytics.__version__", "ultralytics.utils.ARGV", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LOGGER", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.TQDM", "ultralytics.utils.TryExcept", "ultralytics.utils.colorstr", "ultralytics.utils.get_git_origin_url", "ultralytics.utils.downloads.GITHUB_ASSETS_NAMES", "ultralytics.utils.torch_utils.get_cpu_info", "google.colab.output", "IPython.display" ], "chunk_id": "class_Events_ae76603e" }, { "content": "events = Events()", "chunk_type": "variable", "name": "events", "file_path": "ultralytics\\ultralytics\\hub\\utils.py", "start_line": 267, "end_line": 267, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_events_a5d45cbf" }, { "content": "import requests", "chunk_type": "import", "name": "requests", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_requests_bc34ea6c" }, { "content": "from ultralytics.data.utils import HUBDatasetStats", "chunk_type": "import", "name": "HUBDatasetStats", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HUBDatasetStats_59ce08f0" }, { "content": "from ultralytics.hub.auth import Auth", "chunk_type": "import", "name": "Auth", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Auth_775b9b50" }, { "content": "from ultralytics.hub.session import HUBTrainingSession", "chunk_type": "import", "name": "HUBTrainingSession", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HUBTrainingSession_74c2d329" }, { "content": "from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events", "chunk_type": "import", "name": "HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events", 
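Note: the following is an illustrative sketch (not the library's implementation) of the queue-then-flush rate limiting that Events.__call__ follows: events accumulate locally, capped at a maximum count, and are only sent once the rate-limit window has elapsed. The class and `send` callable here are hypothetical.

import time

class RateLimitedQueue:
    """Minimal sketch of queueing items and flushing them at most once per rate-limit window."""

    def __init__(self, rate_limit: float = 30.0, max_items: int = 25):
        self.items = []           # queued payloads awaiting the next flush
        self.rate_limit = rate_limit
        self.max_items = max_items
        self.t = 0.0              # timestamp of the last flush

    def add(self, item, send):
        if len(self.items) < self.max_items:
            self.items.append(item)  # items beyond the cap are dropped
        now = time.time()
        if now - self.t < self.rate_limit:
            return                   # still inside the rate-limit window, wait for a later call
        send(self.items)             # e.g. a single HTTP POST of everything queued
        self.items, self.t = [], now

queue = RateLimitedQueue(rate_limit=30.0)
queue.add({"name": "train"}, send=print)  # `print` stands in for a real sender in this sketch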
"file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 76, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events_01401ddc" }, { "content": "from ultralytics.utils import LOGGER, SETTINGS, checks", "chunk_type": "import", "name": "LOGGER, SETTINGS, checks", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, SETTINGS, checks_ff8f5111" }, { "content": "__all__ = (\n \"PREFIX\",\n \"HUB_WEB_ROOT\",\n \"HUBTrainingSession\",\n \"login\",\n \"logout\",\n \"reset_model\",\n \"export_fmts_hub\",\n \"export_model\",\n \"get_export\",\n \"check_dataset\",\n \"events\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 11, "end_line": 23, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___708bebdf" }, { "content": "def login(api_key: str = None, save: bool = True) -> bool:\n \"\"\"\n Log in to the Ultralytics HUB API using the provided API key.\n\n The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY\n environment variable if successfully authenticated.\n\n Args:\n api_key (str, optional): API key to use for authentication. 
If not provided, it will be retrieved from\n SETTINGS or HUB_API_KEY environment variable.\n save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.\n\n Returns:\n (bool): True if authentication is successful, False otherwise.\n \"\"\"\n checks.check_requirements(\"hub-sdk>=0.0.12\")\n from hub_sdk import HUBClient\n\n api_key_url = f\"{HUB_WEB_ROOT}/settings?tab=api+keys\" # set the redirect URL\n saved_key = SETTINGS.get(\"api_key\")\n active_key = api_key or saved_key\n credentials = {\"api_key\": active_key} if active_key and active_key != \"\" else None # set credentials\n\n client = HUBClient(credentials) # initialize HUBClient\n\n if client.authenticated:\n # Successfully authenticated with HUB\n\n if save and client.api_key != saved_key:\n SETTINGS.update({\"api_key\": client.api_key}) # update settings with valid API key\n\n # Set message based on whether key was provided or retrieved from settings\n log_message = (\n \"New authentication successful ✅\" if client.api_key == api_key or not credentials else \"Authenticated ✅\"\n )\n LOGGER.info(f\"{PREFIX}{log_message}\")\n\n return True\n else:\n # Failed to authenticate with HUB\n LOGGER.info(f\"{PREFIX}Get API key from {api_key_url} and then run 'yolo login API_KEY'\")\n return False", "chunk_type": "function", "name": "login", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 26, "end_line": 67, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Log in to the Ultralytics HUB API using the provided API key.\n\nThe session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY\nenvironment variable if successfully authenticated.\n\nArgs:\n api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from\n SETTINGS or HUB_API_KEY environment variable.\n save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.\n\nReturns:\n (bool): True if authentication is successful, False otherwise.", "parameters": [ "api_key: str", "save: bool" ], "return_type": "bool", "decorators": [], "complexity_score": 3, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_login_934a779d" }, { "content": "def logout():\n \"\"\"Log out of Ultralytics HUB by removing the API key from the settings file.\"\"\"\n SETTINGS[\"api_key\"] = \"\"\n LOGGER.info(f\"{PREFIX}logged out ✅. 
To log in again, use 'yolo login'.\")", "chunk_type": "function", "name": "logout", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 70, "end_line": 73, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": "Log out of Ultralytics HUB by removing the API key from the settings file.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_logout_6e0abd8b" }, { "content": "def reset_model(model_id: str = \"\"):\n \"\"\"Reset a trained model to an untrained state.\"\"\"\n r = requests.post(f\"{HUB_API_ROOT}/model-reset\", json={\"modelId\": model_id}, headers={\"x-api-key\": Auth().api_key})\n if r.status_code == 200:\n LOGGER.info(f\"{PREFIX}Model reset successfully\")\n return\n LOGGER.warning(f\"{PREFIX}Model reset failure {r.status_code} {r.reason}\")", "chunk_type": "function", "name": "reset_model", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 76, "end_line": 82, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": "Reset a trained model to an untrained state.", "parameters": [ "model_id: str" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_reset_model_f6cdd4f2" }, { "content": "def export_fmts_hub():\n \"\"\"Return a list of HUB-supported export formats.\"\"\"\n from ultralytics.engine.exporter import export_formats\n\n return list(export_formats()[\"Argument\"][1:]) + [\"ultralytics_tflite\", \"ultralytics_coreml\"]", "chunk_type": "function", "name": "export_fmts_hub", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 85, "end_line": 89, "start_col": 0, "end_col": 96, "parent_name": null, "docstring": "Return a list of HUB-supported export formats.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_export_fmts_hub_3977e3b9" }, { "content": "def export_model(model_id: str = \"\", format: str = \"torchscript\"):\n \"\"\"\n Export a model to a specified format for deployment via the Ultralytics HUB API.\n\n Args:\n model_id (str): The ID of the model to export. An empty string will use the default model.\n format (str): The format to export the model to. 
Must be one of the supported formats returned by\n export_fmts_hub().\n\n Raises:\n AssertionError: If the specified format is not supported or if the export request fails.\n\n Examples:\n >>> from ultralytics import hub\n >>> hub.export_model(model_id=\"your_model_id\", format=\"torchscript\")\n \"\"\"\n assert format in export_fmts_hub(), f\"Unsupported export format '{format}', valid formats are {export_fmts_hub()}\"\n r = requests.post(\n f\"{HUB_API_ROOT}/v1/models/{model_id}/export\", json={\"format\": format}, headers={\"x-api-key\": Auth().api_key}\n )\n assert r.status_code == 200, f\"{PREFIX}{format} export failure {r.status_code} {r.reason}\"\n LOGGER.info(f\"{PREFIX}{format} export started ✅\")", "chunk_type": "function", "name": "export_model", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 92, "end_line": 113, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Export a model to a specified format for deployment via the Ultralytics HUB API.\n\nArgs:\n model_id (str): The ID of the model to export. An empty string will use the default model.\n format (str): The format to export the model to. Must be one of the supported formats returned by\n export_fmts_hub().\n\nRaises:\n AssertionError: If the specified format is not supported or if the export request fails.\n\nExamples:\n >>> from ultralytics import hub\n >>> hub.export_model(model_id=\"your_model_id\", format=\"torchscript\")", "parameters": [ "model_id: str", "format: str" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_export_model_af03d365" }, { "content": "def get_export(model_id: str = \"\", format: str = \"torchscript\"):\n \"\"\"\n Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.\n\n Args:\n model_id (str): The ID of the model to retrieve from Ultralytics HUB.\n format (str): The export format to retrieve. 
Must be one of the supported formats returned by\n export_fmts_hub().\n\n Returns:\n (dict): JSON response containing the exported model information.\n\n Raises:\n AssertionError: If the specified format is not supported or if the API request fails.\n\n Examples:\n >>> from ultralytics import hub\n >>> result = hub.get_export(model_id=\"your_model_id\", format=\"torchscript\")\n \"\"\"\n assert format in export_fmts_hub(), f\"Unsupported export format '{format}', valid formats are {export_fmts_hub()}\"\n r = requests.post(\n f\"{HUB_API_ROOT}/get-export\",\n json={\"apiKey\": Auth().api_key, \"modelId\": model_id, \"format\": format},\n headers={\"x-api-key\": Auth().api_key},\n )\n assert r.status_code == 200, f\"{PREFIX}{format} get_export failure {r.status_code} {r.reason}\"\n return r.json()", "chunk_type": "function", "name": "get_export", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 116, "end_line": 142, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.\n\nArgs:\n model_id (str): The ID of the model to retrieve from Ultralytics HUB.\n format (str): The export format to retrieve. Must be one of the supported formats returned by\n export_fmts_hub().\n\nReturns:\n (dict): JSON response containing the exported model information.\n\nRaises:\n AssertionError: If the specified format is not supported or if the API request fails.\n\nExamples:\n >>> from ultralytics import hub\n >>> result = hub.get_export(model_id=\"your_model_id\", format=\"torchscript\")", "parameters": [ "model_id: str", "format: str" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_get_export_e81421af" }, { "content": "def check_dataset(path: str, task: str) -> None:\n \"\"\"\n Check HUB dataset Zip file for errors before upload.\n\n Args:\n path (str): Path to data.zip (with data.yaml inside data.zip).\n task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.\n\n Examples:\n >>> from ultralytics.hub import check_dataset\n >>> check_dataset(\"path/to/coco8.zip\", task=\"detect\") # detect dataset\n >>> check_dataset(\"path/to/coco8-seg.zip\", task=\"segment\") # segment dataset\n >>> check_dataset(\"path/to/coco8-pose.zip\", task=\"pose\") # pose dataset\n >>> check_dataset(\"path/to/dota8.zip\", task=\"obb\") # OBB dataset\n >>> check_dataset(\"path/to/imagenet10.zip\", task=\"classify\") # classification dataset\n\n Notes:\n Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets\n i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.\n \"\"\"\n HUBDatasetStats(path=path, task=task).get_json()\n LOGGER.info(f\"Checks completed correctly ✅. 
Upload this dataset to {HUB_WEB_ROOT}/datasets/.\")", "chunk_type": "function", "name": "check_dataset", "file_path": "ultralytics\\ultralytics\\hub\\__init__.py", "start_line": 145, "end_line": 166, "start_col": 0, "end_col": 100, "parent_name": null, "docstring": "Check HUB dataset Zip file for errors before upload.\n\nArgs:\n path (str): Path to data.zip (with data.yaml inside data.zip).\n task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.\n\nExamples:\n >>> from ultralytics.hub import check_dataset\n >>> check_dataset(\"path/to/coco8.zip\", task=\"detect\") # detect dataset\n >>> check_dataset(\"path/to/coco8-seg.zip\", task=\"segment\") # segment dataset\n >>> check_dataset(\"path/to/coco8-pose.zip\", task=\"pose\") # pose dataset\n >>> check_dataset(\"path/to/dota8.zip\", task=\"obb\") # OBB dataset\n >>> check_dataset(\"path/to/imagenet10.zip\", task=\"classify\") # classification dataset\n\nNotes:\n Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets\n i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.", "parameters": [ "path: str", "task: str" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "requests", "ultralytics.data.utils.HUBDatasetStats", "ultralytics.hub.auth.Auth", "ultralytics.hub.session.HUBTrainingSession", "ultralytics.hub.utils.HUB_API_ROOT", "ultralytics.hub.utils.HUB_WEB_ROOT", "ultralytics.hub.utils.PREFIX", "ultralytics.hub.utils.events", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.checks", "hub_sdk.HUBClient", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_check_dataset_cb9bc8b7" }, { "content": "from .fastsam import FastSAM", "chunk_type": "import", "name": "FastSAM", "file_path": "ultralytics\\ultralytics\\models\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FastSAM_8c78c7c3" }, { "content": "from .nas import NAS", "chunk_type": "import", "name": "NAS", "file_path": "ultralytics\\ultralytics\\models\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NAS_4176bad0" }, { "content": "from .rtdetr import RTDETR", "chunk_type": "import", "name": "RTDETR", "file_path": "ultralytics\\ultralytics\\models\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETR_c824f4d1" }, { "content": "from .sam import SAM", "chunk_type": "import", "name": "SAM", "file_path": "ultralytics\\ultralytics\\models\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SAM_81acb123" }, { "content": "from .yolo import YOLO, YOLOE, YOLOWorld", "chunk_type": "import", "name": "YOLO, YOLOE, YOLOWorld", "file_path": "ultralytics\\ultralytics\\models\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 40, 
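Note: an illustrative usage sketch (not from the extracted source) tying the module-level HUB helpers above together; "YOUR_API_KEY" and "your_model_id" are placeholders and the calls assume an installed ultralytics package plus valid HUB credentials.

from ultralytics import hub

if hub.login("YOUR_API_KEY"):  # placeholder key; returns True on successful authentication
    print(hub.export_fmts_hub())  # list HUB-supported export formats
    hub.export_model(model_id="your_model_id", format="onnx")  # start an export job
    info = hub.get_export(model_id="your_model_id", format="onnx")  # retrieve the exported model info
    print(info)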
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO, YOLOE, YOLOWorld_0fa34b0f" }, { "content": "__all__ = \"YOLO\", \"RTDETR\", \"SAM\", \"FastSAM\", \"NAS\", \"YOLOWorld\", \"YOLOE\" # allow simpler import", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\__init__.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___9b085dd3" }, { "content": "import ast", "chunk_type": "import", "name": "ast", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ast_f35612d6" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_9c66a03f" }, { "content": "import platform", "chunk_type": "import", "name": "platform", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_platform_121dba93" }, { "content": "import zipfile", "chunk_type": "import", "name": "zipfile", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_zipfile_511d30d5" }, { "content": "from collections import OrderedDict, namedtuple", "chunk_type": "import", "name": "OrderedDict, namedtuple", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OrderedDict, namedtuple_17ee1e05" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_6dca160b" }, { "content": "from typing import Any, Dict, List, Optional, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple, 
Union_5b2f133a" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_479bdd9c" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_c73a4586" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_7b2c5672" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_ce6432c4" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_cb52f1fd" }, { "content": "from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML", "chunk_type": "import", "name": "ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 89, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML_e8d04163" }, { "content": "from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip", "chunk_type": "import", "name": "check_requirements, check_suffix, check_version, check_yaml, is_rockchip", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 109, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements, check_suffix, check_version, check_yaml, is_rockchip_f17fdfcc" }, { "content": "from ultralytics.utils.downloads import attempt_download_asset, is_url", "chunk_type": "import", "name": "attempt_download_asset, is_url", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
"dependencies": null, "chunk_id": "import_attempt_download_asset, is_url_82b27611" }, { "content": "def check_class_names(names: Union[List, Dict]) -> Dict[int, str]:\n \"\"\"\n Check class names and convert to dict format if needed.\n\n Args:\n names (list | dict): Class names as list or dict format.\n\n Returns:\n (dict): Class names in dict format with integer keys and string values.\n\n Raises:\n KeyError: If class indices are invalid for the dataset size.\n \"\"\"\n if isinstance(names, list): # names is a list\n names = dict(enumerate(names)) # convert to dict\n if isinstance(names, dict):\n # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'\n names = {int(k): str(v) for k, v in names.items()}\n n = len(names)\n if max(names.keys()) >= n:\n raise KeyError(\n f\"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices \"\n f\"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.\"\n )\n if isinstance(names[0], str) and names[0].startswith(\"n0\"): # imagenet class codes, i.e. 'n01440764'\n names_map = YAML.load(ROOT / \"cfg/datasets/ImageNet.yaml\")[\"map\"] # human-readable names\n names = {k: names_map[v] for k, v in names.items()}\n return names", "chunk_type": "function", "name": "check_class_names", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 22, "end_line": 49, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Check class names and convert to dict format if needed.\n\nArgs:\n names (list | dict): Class names as list or dict format.\n\nReturns:\n (dict): Class names in dict format with integer keys and string values.\n\nRaises:\n KeyError: If class indices are invalid for the dataset size.", "parameters": [ "names: Union[List, Dict]" ], "return_type": "Dict[int, str]", "decorators": [], "complexity_score": 7, "dependencies": [ "ast", "json", "platform", "zipfile", "collections.OrderedDict", "collections.namedtuple", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.ARM64", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.ROOT", "ultralytics.utils.YAML", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.checks.is_rockchip", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.is_url", "torchvision", "ultralytics.engine.exporter.export_formats", "urllib.parse.urlsplit", "ultralytics.nn.tasks.attempt_load_weights", "torchvision", "onnxruntime", "mct_quantizers", "sony_custom_layers.pytorch.nms.nms_ort", "openvino", "tensorrt", "coremltools", "tensorrt", "tensorflow", "tensorflow", "ultralytics.engine.exporter.gd_outputs", "tflite_runtime.interpreter.Interpreter", "tflite_runtime.interpreter.load_delegate", "tensorflow", "paddle.inference", "os", "MNN", "ncnn", "ultralytics.utils.triton.TritonRemoteModel", "rknnlite.api.RKNNLite", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_check_class_names_8c4db698" }, { "content": "def default_class_names(data: Optional[Union[str, Path]] = None) -> Dict[int, str]:\n \"\"\"\n Apply default class names to an input YAML file or return numerical class names.\n\n Args:\n data (str | Path, optional): 
Path to YAML file containing class names.\n\n Returns:\n (dict): Dictionary mapping class indices to class names.\n \"\"\"\n if data:\n try:\n return YAML.load(check_yaml(data))[\"names\"]\n except Exception:\n pass\n return {i: f\"class{i}\" for i in range(999)} # return default if above errors", "chunk_type": "function", "name": "default_class_names", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 52, "end_line": 67, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Apply default class names to an input YAML file or return numerical class names.\n\nArgs:\n data (str | Path, optional): Path to YAML file containing class names.\n\nReturns:\n (dict): Dictionary mapping class indices to class names.", "parameters": [ "data: Optional[Union[str, Path]]" ], "return_type": "Dict[int, str]", "decorators": [], "complexity_score": 4, "dependencies": [ "ast", "json", "platform", "zipfile", "collections.OrderedDict", "collections.namedtuple", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.ARM64", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.ROOT", "ultralytics.utils.YAML", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.checks.is_rockchip", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.is_url", "torchvision", "ultralytics.engine.exporter.export_formats", "urllib.parse.urlsplit", "ultralytics.nn.tasks.attempt_load_weights", "torchvision", "onnxruntime", "mct_quantizers", "sony_custom_layers.pytorch.nms.nms_ort", "openvino", "tensorrt", "coremltools", "tensorrt", "tensorflow", "tensorflow", "ultralytics.engine.exporter.gd_outputs", "tflite_runtime.interpreter.Interpreter", "tflite_runtime.interpreter.load_delegate", "tensorflow", "paddle.inference", "os", "MNN", "ncnn", "ultralytics.utils.triton.TritonRemoteModel", "rknnlite.api.RKNNLite", "ultralytics.engine.exporter.export_formats" ], "chunk_id": "function_default_class_names_1a7aa65a" }, { "content": "class AutoBackend(nn.Module):\n \"\"\"\n Handle dynamic backend selection for running inference using Ultralytics YOLO models.\n\n The AutoBackend class is designed to provide an abstraction layer for various inference engines. 
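A short sketch of `default_class_names` behaviour, again assuming `ultralytics` is installed; `"coco8.yaml"` is only an example dataset file name.

```python
# Hedged sketch of default_class_names behaviour (assumes ultralytics is installed).
from ultralytics.nn.autobackend import default_class_names

# With no data file, a 999-entry placeholder mapping is returned.
names = default_class_names()
print(names[0], names[998])  # class0 class998

# With a dataset YAML (e.g. "coco8.yaml"), that file's own "names" mapping is used,
# silently falling back to the placeholder mapping if the file cannot be read.
names = default_class_names("coco8.yaml")
```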
It supports a wide\n range of formats, each with specific naming conventions as outlined below:\n\n Supported Formats and Naming Conventions:\n | Format | File Suffix |\n | --------------------- | ----------------- |\n | PyTorch | *.pt |\n | TorchScript | *.torchscript |\n | ONNX Runtime | *.onnx |\n | ONNX OpenCV DNN | *.onnx (dnn=True) |\n | OpenVINO | *openvino_model/ |\n | CoreML | *.mlpackage |\n | TensorRT | *.engine |\n | TensorFlow SavedModel | *_saved_model/ |\n | TensorFlow GraphDef | *.pb |\n | TensorFlow Lite | *.tflite |\n | TensorFlow Edge TPU | *_edgetpu.tflite |\n | PaddlePaddle | *_paddle_model/ |\n | MNN | *.mnn |\n | NCNN | *_ncnn_model/ |\n | IMX | *_imx_model/ |\n | RKNN | *_rknn_model/ |\n\n Attributes:\n model (torch.nn.Module): The loaded YOLO model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n task (str): The type of task the model performs (detect, segment, classify, pose).\n names (dict): A dictionary of class names that the model can detect.\n stride (int): The model stride, typically 32 for YOLO models.\n fp16 (bool): Whether the model uses half-precision (FP16) inference.\n nhwc (bool): Whether the model expects NHWC input format instead of NCHW.\n pt (bool): Whether the model is a PyTorch model.\n jit (bool): Whether the model is a TorchScript model.\n onnx (bool): Whether the model is an ONNX model.\n xml (bool): Whether the model is an OpenVINO model.\n engine (bool): Whether the model is a TensorRT engine.\n coreml (bool): Whether the model is a CoreML model.\n saved_model (bool): Whether the model is a TensorFlow SavedModel.\n pb (bool): Whether the model is a TensorFlow GraphDef.\n tflite (bool): Whether the model is a TensorFlow Lite model.\n edgetpu (bool): Whether the model is a TensorFlow Edge TPU model.\n tfjs (bool): Whether the model is a TensorFlow.js model.\n paddle (bool): Whether the model is a PaddlePaddle model.\n mnn (bool): Whether the model is an MNN model.\n ncnn (bool): Whether the model is an NCNN model.\n imx (bool): Whether the model is an IMX model.\n rknn (bool): Whether the model is an RKNN model.\n triton (bool): Whether the model is a Triton Inference Server model.\n\n Methods:\n forward: Run inference on an input image.\n from_numpy: Convert numpy array to tensor.\n warmup: Warm up the model with a dummy input.\n _model_type: Determine the model type from file path.\n\n Examples:\n >>> model = AutoBackend(weights=\"yolo11n.pt\", device=\"cuda\")\n >>> results = model(img)\n \"\"\"\n\n @torch.no_grad()\n def __init__(\n self,\n weights: Union[str, List[str], torch.nn.Module] = \"yolo11n.pt\",\n device: torch.device = torch.device(\"cpu\"),\n dnn: bool = False,\n data: Optional[Union[str, Path]] = None,\n fp16: bool = False,\n batch: int = 1,\n fuse: bool = True,\n verbose: bool = True,\n ):\n \"\"\"\n Initialize the AutoBackend for inference.\n\n Args:\n weights (str | List[str] | torch.nn.Module): Path to the model weights file or a module instance.\n device (torch.device): Device to run the model on.\n dnn (bool): Use OpenCV DNN module for ONNX inference.\n data (str | Path, optional): Path to the additional data.yaml file containing class names.\n fp16 (bool): Enable half-precision inference. 
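A hedged construction sketch for `AutoBackend`, based only on the signature and attributes shown above; the weight file name and image sizes are placeholders, and downloading/loading real weights is assumed to work in the target environment.

```python
# Sketch: constructing AutoBackend and running one dummy forward pass.
import torch
from ultralytics.nn.autobackend import AutoBackend

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
backend = AutoBackend(weights="yolo11n.pt", device=device, fp16=False)  # placeholder weights

backend.warmup(imgsz=(1, 3, 640, 640))            # one dummy forward pass
im = torch.zeros(1, 3, 640, 640, device=device)   # BCHW float input
y = backend.forward(im)                           # raw prediction tensor(s)
print(backend.task, backend.stride, len(backend.names))
```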
Supported only on specific backends.\n batch (int): Batch-size to assume for inference.\n fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.\n verbose (bool): Enable verbose logging.\n \"\"\"\n super().__init__()\n w = str(weights[0] if isinstance(weights, list) else weights)\n nn_module = isinstance(weights, torch.nn.Module)\n (\n pt,\n jit,\n onnx,\n xml,\n engine,\n coreml,\n saved_model,\n pb,\n tflite,\n edgetpu,\n tfjs,\n paddle,\n mnn,\n ncnn,\n imx,\n rknn,\n triton,\n ) = self._model_type(w)\n fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16\n nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCWH)\n stride, ch = 32, 3 # default stride and channels\n end2end, dynamic = False, False\n model, metadata, task = None, None, None\n\n # Set device\n cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != \"cpu\" # use CUDA\n if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]): # GPU dataloader formats\n device = torch.device(\"cpu\")\n cuda = False\n\n # Download if not local\n if not (pt or triton or nn_module):\n w = attempt_download_asset(w)\n\n # In-memory PyTorch model\n if nn_module:\n model = weights.to(device)\n if fuse:\n model = model.fuse(verbose=verbose)\n if hasattr(model, \"kpt_shape\"):\n kpt_shape = model.kpt_shape # pose-only\n stride = max(int(model.stride.max()), 32) # model stride\n names = model.module.names if hasattr(model, \"module\") else model.names # get class names\n model.half() if fp16 else model.float()\n ch = model.yaml.get(\"channels\", 3)\n self.model = model # explicitly assign for to(), cpu(), cuda(), half()\n pt = True\n\n # PyTorch\n elif pt:\n from ultralytics.nn.tasks import attempt_load_weights\n\n model = attempt_load_weights(\n weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse\n )\n if hasattr(model, \"kpt_shape\"):\n kpt_shape = model.kpt_shape # pose-only\n stride = max(int(model.stride.max()), 32) # model stride\n names = model.module.names if hasattr(model, \"module\") else model.names # get class names\n model.half() if fp16 else model.float()\n ch = model.yaml.get(\"channels\", 3)\n self.model = model # explicitly assign for to(), cpu(), cuda(), half()\n\n # TorchScript\n elif jit:\n import torchvision # noqa - https://github.com/ultralytics/ultralytics/pull/19747\n\n LOGGER.info(f\"Loading {w} for TorchScript inference...\")\n extra_files = {\"config.txt\": \"\"} # model metadata\n model = torch.jit.load(w, _extra_files=extra_files, map_location=device)\n model.half() if fp16 else model.float()\n if extra_files[\"config.txt\"]: # load metadata dict\n metadata = json.loads(extra_files[\"config.txt\"], object_hook=lambda x: dict(x.items()))\n\n # ONNX OpenCV DNN\n elif dnn:\n LOGGER.info(f\"Loading {w} for ONNX OpenCV DNN inference...\")\n check_requirements(\"opencv-python>=4.5.4\")\n net = cv2.dnn.readNetFromONNX(w)\n\n # ONNX Runtime and IMX\n elif onnx or imx:\n LOGGER.info(f\"Loading {w} for ONNX Runtime inference...\")\n check_requirements((\"onnx\", \"onnxruntime-gpu\" if cuda else \"onnxruntime\"))\n import onnxruntime\n\n providers = [\"CPUExecutionProvider\"]\n if cuda:\n if \"CUDAExecutionProvider\" in onnxruntime.get_available_providers():\n providers.insert(0, \"CUDAExecutionProvider\")\n else: # Only log warning if CUDA was requested but unavailable\n LOGGER.warning(\"Failed to start ONNX Runtime with CUDA. 
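A standalone sketch of the ONNX Runtime provider-selection pattern used in the branch above: CUDA is preferred only when its execution provider is actually available. It assumes `onnxruntime` is installed and uses `"model.onnx"` as a placeholder path.

```python
# Provider selection sketch for ONNX Runtime (model.onnx is a placeholder path).
import onnxruntime

cuda_requested = True
providers = ["CPUExecutionProvider"]
if cuda_requested and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
    providers.insert(0, "CUDAExecutionProvider")  # prefer the GPU EP when present

session = onnxruntime.InferenceSession("model.onnx", providers=providers)
output_names = [o.name for o in session.get_outputs()]
dynamic = isinstance(session.get_outputs()[0].shape[0], str)  # symbolic batch dim
```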
Using CPU...\")\n device = torch.device(\"cpu\")\n cuda = False\n LOGGER.info(f\"Using ONNX Runtime {providers[0]}\")\n if onnx:\n session = onnxruntime.InferenceSession(w, providers=providers)\n else:\n check_requirements(\n [\"model-compression-toolkit>=2.4.1\", \"sony-custom-layers[torch]>=0.3.0\", \"onnxruntime-extensions\"]\n )\n w = next(Path(w).glob(\"*.onnx\"))\n LOGGER.info(f\"Loading {w} for ONNX IMX inference...\")\n import mct_quantizers as mctq\n from sony_custom_layers.pytorch.nms import nms_ort # noqa\n\n session_options = mctq.get_ort_session_options()\n session_options.enable_mem_reuse = False # fix the shape mismatch from onnxruntime\n session = onnxruntime.InferenceSession(w, session_options, providers=[\"CPUExecutionProvider\"])\n\n output_names = [x.name for x in session.get_outputs()]\n metadata = session.get_modelmeta().custom_metadata_map\n dynamic = isinstance(session.get_outputs()[0].shape[0], str)\n fp16 = \"float16\" in session.get_inputs()[0].type\n if not dynamic:\n io = session.io_binding()\n bindings = []\n for output in session.get_outputs():\n out_fp16 = \"float16\" in output.type\n y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)\n io.bind_output(\n name=output.name,\n device_type=device.type,\n device_id=device.index if cuda else 0,\n element_type=np.float16 if out_fp16 else np.float32,\n shape=tuple(y_tensor.shape),\n buffer_ptr=y_tensor.data_ptr(),\n )\n bindings.append(y_tensor)\n\n # OpenVINO\n elif xml:\n LOGGER.info(f\"Loading {w} for OpenVINO inference...\")\n check_requirements(\"openvino>=2024.0.0\")\n import openvino as ov\n\n core = ov.Core()\n device_name = \"AUTO\"\n if isinstance(device, str) and device.startswith(\"intel\"):\n device_name = device.split(\":\")[1].upper() # Intel OpenVINO device\n device = torch.device(\"cpu\")\n if device_name not in core.available_devices:\n LOGGER.warning(f\"OpenVINO device '{device_name}' not available. 
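A hedged sketch of the output pre-binding used above for static-shape ONNX models: outputs are bound once to pre-allocated torch tensors so `run_with_iobinding()` writes results in place. `"model.onnx"` is a placeholder, the example stays on CPU, and float32 static output shapes are assumed.

```python
# IOBinding sketch for static-shape ONNX models (CPU-only, float32 outputs assumed).
import numpy as np
import onnxruntime
import torch

session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
io = session.io_binding()
bindings = []
for out in session.get_outputs():
    buf = torch.empty(out.shape, dtype=torch.float32)  # pre-allocated output buffer
    io.bind_output(name=out.name, device_type="cpu", device_id=0,
                   element_type=np.float32, shape=tuple(buf.shape),
                   buffer_ptr=buf.data_ptr())
    bindings.append(buf)

im = torch.zeros(tuple(session.get_inputs()[0].shape), dtype=torch.float32)
io.bind_input(name=session.get_inputs()[0].name, device_type="cpu", device_id=0,
              element_type=np.float32, shape=tuple(im.shape), buffer_ptr=im.data_ptr())
session.run_with_iobinding(io)  # results land in the `bindings` tensors
```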
Using 'AUTO' instead.\")\n device_name = \"AUTO\"\n w = Path(w)\n if not w.is_file(): # if not *.xml\n w = next(w.glob(\"*.xml\")) # get *.xml file from *_openvino_model dir\n ov_model = core.read_model(model=str(w), weights=w.with_suffix(\".bin\"))\n if ov_model.get_parameters()[0].get_layout().empty:\n ov_model.get_parameters()[0].set_layout(ov.Layout(\"NCHW\"))\n\n # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'\n inference_mode = \"CUMULATIVE_THROUGHPUT\" if batch > 1 else \"LATENCY\"\n LOGGER.info(f\"Using OpenVINO {inference_mode} mode for batch={batch} inference...\")\n ov_compiled_model = core.compile_model(\n ov_model,\n device_name=device_name,\n config={\"PERFORMANCE_HINT\": inference_mode},\n )\n input_name = ov_compiled_model.input().get_any_name()\n metadata = w.parent / \"metadata.yaml\"\n\n # TensorRT\n elif engine:\n LOGGER.info(f\"Loading {w} for TensorRT inference...\")\n\n if IS_JETSON and check_version(PYTHON_VERSION, \"<=3.8.10\"):\n # fix error: `np.bool` was a deprecated alias for the builtin `bool` for JetPack 4 and JetPack 5 with Python <= 3.8.10\n check_requirements(\"numpy==1.23.5\")\n\n try: # https://developer.nvidia.com/nvidia-tensorrt-download\n import tensorrt as trt # noqa\n except ImportError:\n if LINUX:\n check_requirements(\"tensorrt>7.0.0,!=10.1.0\")\n import tensorrt as trt # noqa\n check_version(trt.__version__, \">=7.0.0\", hard=True)\n check_version(trt.__version__, \"!=10.1.0\", msg=\"https://github.com/ultralytics/ultralytics/pull/14239\")\n if device.type == \"cpu\":\n device = torch.device(\"cuda:0\")\n Binding = namedtuple(\"Binding\", (\"name\", \"dtype\", \"shape\", \"data\", \"ptr\"))\n logger = trt.Logger(trt.Logger.INFO)\n # Read file\n with open(w, \"rb\") as f, trt.Runtime(logger) as runtime:\n try:\n meta_len = int.from_bytes(f.read(4), byteorder=\"little\") # read metadata length\n metadata = json.loads(f.read(meta_len).decode(\"utf-8\")) # read metadata\n dla = metadata.get(\"dla\", None)\n if dla is not None:\n runtime.DLA_core = int(dla)\n except UnicodeDecodeError:\n f.seek(0) # engine file may lack embedded Ultralytics metadata\n model = runtime.deserialize_cuda_engine(f.read()) # read engine\n\n # Model context\n try:\n context = model.create_execution_context()\n except Exception as e: # model is None\n LOGGER.error(f\"TensorRT model exported with a different version than {trt.__version__}\\n\")\n raise e\n\n bindings = OrderedDict()\n output_names = []\n fp16 = False # default updated below\n dynamic = False\n is_trt10 = not hasattr(model, \"num_bindings\")\n num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)\n for i in num:\n if is_trt10:\n name = model.get_tensor_name(i)\n dtype = trt.nptype(model.get_tensor_dtype(name))\n is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT\n if is_input:\n if -1 in tuple(model.get_tensor_shape(name)):\n dynamic = True\n context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))\n if dtype == np.float16:\n fp16 = True\n else:\n output_names.append(name)\n shape = tuple(context.get_tensor_shape(name))\n else: # TensorRT < 10.0\n name = model.get_binding_name(i)\n dtype = trt.nptype(model.get_binding_dtype(i))\n is_input = model.binding_is_input(i)\n if model.binding_is_input(i):\n if -1 in tuple(model.get_binding_shape(i)): # dynamic\n dynamic = True\n context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))\n if dtype == np.float16:\n fp16 = True\n else:\n 
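A small stdlib-only sketch of the length-prefixed metadata header read in the TensorRT branch above: Ultralytics-exported `.engine` files start with a 4-byte little-endian length followed by JSON metadata, and plain engines without that header fall through on the decode error. `"model.engine"` is a placeholder path.

```python
# Reading the embedded metadata header of an exported .engine file (sketch).
import json

def read_engine_metadata(path: str):
    with open(path, "rb") as f:
        try:
            meta_len = int.from_bytes(f.read(4), byteorder="little")  # header length
            metadata = json.loads(f.read(meta_len).decode("utf-8"))   # JSON metadata
        except UnicodeDecodeError:
            f.seek(0)       # engine lacks an embedded Ultralytics header
            metadata = None
        return metadata, f.read()  # (metadata or None, raw engine bytes)

# metadata, engine_bytes = read_engine_metadata("model.engine")
```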
output_names.append(name)\n shape = tuple(context.get_binding_shape(i))\n im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)\n bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))\n binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())\n batch_size = bindings[\"images\"].shape[0] # if dynamic, this is instead max batch size\n\n # CoreML\n elif coreml:\n LOGGER.info(f\"Loading {w} for CoreML inference...\")\n import coremltools as ct\n\n model = ct.models.MLModel(w)\n metadata = dict(model.user_defined_metadata)\n\n # TF SavedModel\n elif saved_model:\n LOGGER.info(f\"Loading {w} for TensorFlow SavedModel inference...\")\n import tensorflow as tf\n\n keras = False # assume TF1 saved_model\n model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)\n metadata = Path(w) / \"metadata.yaml\"\n\n # TF GraphDef\n elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt\n LOGGER.info(f\"Loading {w} for TensorFlow GraphDef inference...\")\n import tensorflow as tf\n\n from ultralytics.engine.exporter import gd_outputs\n\n def wrap_frozen_graph(gd, inputs, outputs):\n \"\"\"Wrap frozen graphs for deployment.\"\"\"\n x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=\"\"), []) # wrapped\n ge = x.graph.as_graph_element\n return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))\n\n gd = tf.Graph().as_graph_def() # TF GraphDef\n with open(w, \"rb\") as f:\n gd.ParseFromString(f.read())\n frozen_func = wrap_frozen_graph(gd, inputs=\"x:0\", outputs=gd_outputs(gd))\n try: # find metadata in SavedModel alongside GraphDef\n metadata = next(Path(w).resolve().parent.rglob(f\"{Path(w).stem}_saved_model*/metadata.yaml\"))\n except StopIteration:\n pass\n\n # TFLite or TFLite Edge TPU\n elif tflite or edgetpu: # https://ai.google.dev/edge/litert/microcontrollers/python\n try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu\n from tflite_runtime.interpreter import Interpreter, load_delegate\n except ImportError:\n import tensorflow as tf\n\n Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate\n if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime\n device = device[3:] if str(device).startswith(\"tpu\") else \":0\"\n LOGGER.info(f\"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...\")\n delegate = {\"Linux\": \"libedgetpu.so.1\", \"Darwin\": \"libedgetpu.1.dylib\", \"Windows\": \"edgetpu.dll\"}[\n platform.system()\n ]\n interpreter = Interpreter(\n model_path=w,\n experimental_delegates=[load_delegate(delegate, options={\"device\": device})],\n )\n device = \"cpu\" # Required, otherwise PyTorch will try to use the wrong device\n else: # TFLite\n LOGGER.info(f\"Loading {w} for TensorFlow Lite inference...\")\n interpreter = Interpreter(model_path=w) # load TFLite model\n interpreter.allocate_tensors() # allocate\n input_details = interpreter.get_input_details() # inputs\n output_details = interpreter.get_output_details() # outputs\n # Load metadata\n try:\n with zipfile.ZipFile(w, \"r\") as zf:\n name = zf.namelist()[0]\n contents = zf.read(name).decode(\"utf-8\")\n if name == \"metadata.json\": # Custom Ultralytics metadata dict for Python>=3.12\n metadata = json.loads(contents)\n else:\n metadata = ast.literal_eval(contents) # Default tflite-support metadata for Python<=3.11\n except (zipfile.BadZipFile, SyntaxError, ValueError, json.JSONDecodeError):\n pass\n\n 
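A minimal sketch of the TFLite loading path above: prefer the lightweight `tflite_runtime` package and fall back to full TensorFlow. `"model.tflite"` is a placeholder path, and one of the two packages is assumed to be installed.

```python
# TFLite interpreter loading sketch (model.tflite is a placeholder path).
import numpy as np

try:
    from tflite_runtime.interpreter import Interpreter
except ImportError:
    import tensorflow as tf
    Interpreter = tf.lite.Interpreter

interpreter = Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
inp, out = interpreter.get_input_details()[0], interpreter.get_output_details()[0]

x = np.zeros(inp["shape"], dtype=inp["dtype"])  # NHWC dummy input
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
y = interpreter.get_tensor(out["index"])
```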
# TF.js\n elif tfjs:\n raise NotImplementedError(\"YOLOv8 TF.js inference is not currently supported.\")\n\n # PaddlePaddle\n elif paddle:\n LOGGER.info(f\"Loading {w} for PaddlePaddle inference...\")\n check_requirements(\n \"paddlepaddle-gpu\"\n if torch.cuda.is_available()\n else \"paddlepaddle==3.0.0\" # pin 3.0.0 for ARM64\n if ARM64\n else \"paddlepaddle>=3.0.0\"\n )\n import paddle.inference as pdi # noqa\n\n w = Path(w)\n model_file, params_file = None, None\n if w.is_dir():\n model_file = next(w.rglob(\"*.json\"), None)\n params_file = next(w.rglob(\"*.pdiparams\"), None)\n elif w.suffix == \".pdiparams\":\n model_file = w.with_name(\"model.json\")\n params_file = w\n\n if not (model_file and params_file and model_file.is_file() and params_file.is_file()):\n raise FileNotFoundError(f\"Paddle model not found in {w}. Both .json and .pdiparams files are required.\")\n\n config = pdi.Config(str(model_file), str(params_file))\n if cuda:\n config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)\n predictor = pdi.create_predictor(config)\n input_handle = predictor.get_input_handle(predictor.get_input_names()[0])\n output_names = predictor.get_output_names()\n metadata = w / \"metadata.yaml\"\n\n # MNN\n elif mnn:\n LOGGER.info(f\"Loading {w} for MNN inference...\")\n check_requirements(\"MNN\") # requires MNN\n import os\n\n import MNN\n\n config = {\"precision\": \"low\", \"backend\": \"CPU\", \"numThread\": (os.cpu_count() + 1) // 2}\n rt = MNN.nn.create_runtime_manager((config,))\n net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)\n\n def torch_to_mnn(x):\n return MNN.expr.const(x.data_ptr(), x.shape)\n\n metadata = json.loads(net.get_info()[\"bizCode\"])\n\n # NCNN\n elif ncnn:\n LOGGER.info(f\"Loading {w} for NCNN inference...\")\n check_requirements(\"git+https://github.com/Tencent/ncnn.git\" if ARM64 else \"ncnn\") # requires NCNN\n import ncnn as pyncnn\n\n net = pyncnn.Net()\n net.opt.use_vulkan_compute = cuda\n w = Path(w)\n if not w.is_file(): # if not *.param\n w = next(w.glob(\"*.param\")) # get *.param file from *_ncnn_model dir\n net.load_param(str(w))\n net.load_model(str(w.with_suffix(\".bin\")))\n metadata = w.parent / \"metadata.yaml\"\n\n # NVIDIA Triton Inference Server\n elif triton:\n check_requirements(\"tritonclient[all]\")\n from ultralytics.utils.triton import TritonRemoteModel\n\n model = TritonRemoteModel(w)\n metadata = model.metadata\n\n # RKNN\n elif rknn:\n if not is_rockchip():\n raise OSError(\"RKNN inference is only supported on Rockchip devices.\")\n LOGGER.info(f\"Loading {w} for RKNN inference...\")\n check_requirements(\"rknn-toolkit-lite2\")\n from rknnlite.api import RKNNLite\n\n w = Path(w)\n if not w.is_file(): # if not *.rknn\n w = next(w.rglob(\"*.rknn\")) # get *.rknn file from *_rknn_model dir\n rknn_model = RKNNLite()\n rknn_model.load_rknn(str(w))\n rknn_model.init_runtime()\n metadata = w.parent / \"metadata.yaml\"\n\n # Any other format (unsupported)\n else:\n from ultralytics.engine.exporter import export_formats\n\n raise TypeError(\n f\"model='{w}' is not a supported model format. 
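A stdlib-only sketch of the PaddlePaddle file resolution performed in the branch above: an exported `*_paddle_model/` directory must contain both a `*.json` model file and a `*.pdiparams` weights file. The directory name below is a placeholder.

```python
# Resolving the Paddle model/params pair inside an exported directory (sketch).
from pathlib import Path

w = Path("yolo11n_paddle_model")  # placeholder export directory
model_file = next(w.rglob("*.json"), None) if w.is_dir() else None
params_file = next(w.rglob("*.pdiparams"), None) if w.is_dir() else None

if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
    raise FileNotFoundError(f"Paddle model not found in {w}: need *.json and *.pdiparams")
```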
Ultralytics supports: {export_formats()['Format']}\\n\"\n f\"See https://docs.ultralytics.com/modes/predict for help.\"\n )\n\n # Load external metadata YAML\n if isinstance(metadata, (str, Path)) and Path(metadata).exists():\n metadata = YAML.load(metadata)\n if metadata and isinstance(metadata, dict):\n for k, v in metadata.items():\n if k in {\"stride\", \"batch\", \"channels\"}:\n metadata[k] = int(v)\n elif k in {\"imgsz\", \"names\", \"kpt_shape\", \"args\"} and isinstance(v, str):\n metadata[k] = eval(v)\n stride = metadata[\"stride\"]\n task = metadata[\"task\"]\n batch = metadata[\"batch\"]\n imgsz = metadata[\"imgsz\"]\n names = metadata[\"names\"]\n kpt_shape = metadata.get(\"kpt_shape\")\n end2end = metadata.get(\"args\", {}).get(\"nms\", False)\n dynamic = metadata.get(\"args\", {}).get(\"dynamic\", dynamic)\n ch = metadata.get(\"channels\", 3)\n elif not (pt or triton or nn_module):\n LOGGER.warning(f\"Metadata not found for 'model={weights}'\")\n\n # Check names\n if \"names\" not in locals(): # names missing\n names = default_class_names(data)\n names = check_class_names(names)\n\n # Disable gradients\n if pt:\n for p in model.parameters():\n p.requires_grad = False\n\n self.__dict__.update(locals()) # assign all variables to self\n\n def forward(\n self,\n im: torch.Tensor,\n augment: bool = False,\n visualize: bool = False,\n embed: Optional[List] = None,\n **kwargs: Any,\n ) -> Union[torch.Tensor, List[torch.Tensor]]:\n \"\"\"\n Run inference on an AutoBackend model.\n\n Args:\n im (torch.Tensor): The image tensor to perform inference on.\n augment (bool): Whether to perform data augmentation during inference.\n visualize (bool): Whether to visualize the output predictions.\n embed (list, optional): A list of feature vectors/embeddings to return.\n **kwargs (Any): Additional keyword arguments for model configuration.\n\n Returns:\n (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.\n \"\"\"\n b, ch, h, w = im.shape # batch, channel, height, width\n if self.fp16 and im.dtype != torch.float16:\n im = im.half() # to FP16\n if self.nhwc:\n im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)\n\n # PyTorch\n if self.pt or self.nn_module:\n y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)\n\n # TorchScript\n elif self.jit:\n y = self.model(im)\n\n # ONNX OpenCV DNN\n elif self.dnn:\n im = im.cpu().numpy() # torch to numpy\n self.net.setInput(im)\n y = self.net.forward()\n\n # ONNX Runtime\n elif self.onnx or self.imx:\n if self.dynamic:\n im = im.cpu().numpy() # torch to numpy\n y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})\n else:\n if not self.cuda:\n im = im.cpu()\n self.io.bind_input(\n name=\"images\",\n device_type=im.device.type,\n device_id=im.device.index if im.device.type == \"cuda\" else 0,\n element_type=np.float16 if self.fp16 else np.float32,\n shape=tuple(im.shape),\n buffer_ptr=im.data_ptr(),\n )\n self.session.run_with_iobinding(self.io)\n y = self.bindings\n if self.imx:\n if self.task == \"detect\":\n # boxes, conf, cls\n y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)\n elif self.task == \"pose\":\n # boxes, conf, kpts\n y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1)\n\n # OpenVINO\n elif self.xml:\n im = im.cpu().numpy() # FP32\n\n if self.inference_mode in {\"THROUGHPUT\", \"CUMULATIVE_THROUGHPUT\"}: # optimized for larger batch-sizes\n n = im.shape[0] # number of images in batch\n results 
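A hedged sketch of the exported-metadata normalization shown above, using `ast.literal_eval` in place of `eval` for the stringified fields (a safer stand-in, not the library's exact call). The metadata dict is a made-up example of what an exporter might embed.

```python
# Normalizing stringified export metadata (illustrative values).
import ast

metadata = {
    "stride": "32",
    "batch": "1",
    "imgsz": "[640, 640]",
    "names": "{0: 'person', 1: 'car'}",
    "task": "detect",
}
for k, v in metadata.items():
    if k in {"stride", "batch", "channels"}:
        metadata[k] = int(v)
    elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
        metadata[k] = ast.literal_eval(v)  # safer stand-in for eval()

stride, imgsz, names = metadata["stride"], metadata["imgsz"], metadata["names"]
print(stride, imgsz, names[0])  # 32 [640, 640] person
```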
= [None] * n # preallocate list with None to match the number of images\n\n def callback(request, userdata):\n \"\"\"Place result in preallocated list using userdata index.\"\"\"\n results[userdata] = request.results\n\n # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image\n async_queue = self.ov.AsyncInferQueue(self.ov_compiled_model)\n async_queue.set_callback(callback)\n for i in range(n):\n # Start async inference with userdata=i to specify the position in results list\n async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW\n async_queue.wait_all() # wait for all inference requests to complete\n y = np.concatenate([list(r.values())[0] for r in results])\n\n else: # inference_mode = \"LATENCY\", optimized for fastest first result at batch-size 1\n y = list(self.ov_compiled_model(im).values())\n\n # TensorRT\n elif self.engine:\n if self.dynamic and im.shape != self.bindings[\"images\"].shape:\n if self.is_trt10:\n self.context.set_input_shape(\"images\", im.shape)\n self.bindings[\"images\"] = self.bindings[\"images\"]._replace(shape=im.shape)\n for name in self.output_names:\n self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))\n else:\n i = self.model.get_binding_index(\"images\")\n self.context.set_binding_shape(i, im.shape)\n self.bindings[\"images\"] = self.bindings[\"images\"]._replace(shape=im.shape)\n for name in self.output_names:\n i = self.model.get_binding_index(name)\n self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))\n\n s = self.bindings[\"images\"].shape\n assert im.shape == s, f\"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}\"\n self.binding_addrs[\"images\"] = int(im.data_ptr())\n self.context.execute_v2(list(self.binding_addrs.values()))\n y = [self.bindings[x].data for x in sorted(self.output_names)]\n\n # CoreML\n elif self.coreml:\n im = im[0].cpu().numpy()\n im_pil = Image.fromarray((im * 255).astype(\"uint8\"))\n # im = im.resize((192, 320), Image.BILINEAR)\n y = self.model.predict({\"image\": im_pil}) # coordinates are xywh normalized\n if \"confidence\" in y:\n raise TypeError(\n \"Ultralytics only supports inference of non-pipelined CoreML models exported with \"\n f\"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.\"\n )\n # TODO: CoreML NMS inference handling\n # from ultralytics.utils.ops import xywh2xyxy\n # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels\n # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)\n # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)\n y = list(y.values())\n if len(y) == 2 and len(y[1].shape) != 4: # segmentation model\n y = list(reversed(y)) # reversed for segmentation models (pred, proto)\n\n # PaddlePaddle\n elif self.paddle:\n im = im.cpu().numpy().astype(np.float32)\n self.input_handle.copy_from_cpu(im)\n self.predictor.run()\n y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]\n\n # MNN\n elif self.mnn:\n input_var = self.torch_to_mnn(im)\n output_var = self.net.onForward([input_var])\n y = [x.read() for x in output_var]\n\n # NCNN\n elif self.ncnn:\n mat_in = self.pyncnn.Mat(im[0].cpu().numpy())\n with self.net.create_extractor() as ex:\n ex.input(self.net.input_names()[0], mat_in)\n # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130\n y = [np.array(ex.extract(x)[1])[None] 
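A hedged sketch of the `AsyncInferQueue` pattern used above for batched OpenVINO inference; `"model.xml"` is a placeholder path and the batch of zeros is a stand-in input.

```python
# Async batched OpenVINO inference sketch (model.xml is a placeholder path).
import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model(core.read_model("model.xml"), device_name="AUTO",
                              config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"})

im = np.zeros((4, 3, 640, 640), dtype=np.float32)  # example BCHW batch
results = [None] * im.shape[0]

def callback(request, userdata):
    results[userdata] = request.results  # keep results in submission order

queue = ov.AsyncInferQueue(compiled)
queue.set_callback(callback)
for i in range(im.shape[0]):
    queue.start_async(inputs={compiled.input().get_any_name(): im[i:i + 1]}, userdata=i)
queue.wait_all()
y = np.concatenate([list(r.values())[0] for r in results])
```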
for x in sorted(self.net.output_names())]\n\n # NVIDIA Triton Inference Server\n elif self.triton:\n im = im.cpu().numpy() # torch to numpy\n y = self.model(im)\n\n # RKNN\n elif self.rknn:\n im = (im.cpu().numpy() * 255).astype(\"uint8\")\n im = im if isinstance(im, (list, tuple)) else [im]\n y = self.rknn_model.inference(inputs=im)\n\n # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)\n else:\n im = im.cpu().numpy()\n if self.saved_model: # SavedModel\n y = self.model(im, training=False) if self.keras else self.model.serving_default(im)\n if not isinstance(y, list):\n y = [y]\n elif self.pb: # GraphDef\n y = self.frozen_func(x=self.tf.constant(im))\n else: # Lite or Edge TPU\n details = self.input_details[0]\n is_int = details[\"dtype\"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model\n if is_int:\n scale, zero_point = details[\"quantization\"]\n im = (im / scale + zero_point).astype(details[\"dtype\"]) # de-scale\n self.interpreter.set_tensor(details[\"index\"], im)\n self.interpreter.invoke()\n y = []\n for output in self.output_details:\n x = self.interpreter.get_tensor(output[\"index\"])\n if is_int:\n scale, zero_point = output[\"quantization\"]\n x = (x.astype(np.float32) - zero_point) * scale # re-scale\n if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well\n # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695\n # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models\n if x.shape[-1] == 6 or self.end2end: # end-to-end model\n x[:, :, [0, 2]] *= w\n x[:, :, [1, 3]] *= h\n if self.task == \"pose\":\n x[:, :, 6::3] *= w\n x[:, :, 7::3] *= h\n else:\n x[:, [0, 2]] *= w\n x[:, [1, 3]] *= h\n if self.task == \"pose\":\n x[:, 5::3] *= w\n x[:, 6::3] *= h\n y.append(x)\n # TF segment fixes: export is reversed vs ONNX export and protos are transposed\n if len(y) == 2: # segment with (det, proto) output order reversed\n if len(y[1].shape) != 4:\n y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)\n if y[1].shape[-1] == 6: # end-to-end model\n y = [y[1]]\n else:\n y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)\n y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]\n\n # for x in y:\n # print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes\n if isinstance(y, (list, tuple)):\n if len(self.names) == 999 and (self.task == \"segment\" or len(y) == 2): # segments and names not defined\n nc = y[0].shape[1] - y[1].shape[1] - 4 # y = (1, 32, 160, 160), (1, 116, 8400)\n self.names = {i: f\"class{i}\" for i in range(nc)}\n return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]\n else:\n return self.from_numpy(y)\n\n def from_numpy(self, x: np.ndarray) -> torch.Tensor:\n \"\"\"\n Convert a numpy array to a tensor.\n\n Args:\n x (np.ndarray): The array to be converted.\n\n Returns:\n (torch.Tensor): The converted tensor\n \"\"\"\n return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x\n\n def warmup(self, imgsz: Tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:\n \"\"\"\n Warm up the model by running one forward pass with a dummy input.\n\n Args:\n imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)\n \"\"\"\n import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)\n\n warmup_types = self.pt, self.jit, self.onnx, self.engine, 
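A standalone NumPy sketch of the int8 quantize/dequantize round-trip used for quantized TFLite models above; the scale and zero point are illustrative values, not read from a real model.

```python
# Quantize before set_tensor(), dequantize model outputs (illustrative parameters).
import numpy as np

scale, zero_point = 1 / 255.0, -128                        # example quantization params
im = np.random.rand(1, 320, 320, 3).astype(np.float32)      # NHWC input in [0, 1)

q = (im / scale + zero_point).astype(np.int8)               # quantize (de-scale input)
x = (q.astype(np.float32) - zero_point) * scale             # dequantize (re-scale output)
print(np.abs(x - im).max() < scale)                         # error bounded by one step
```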
self.saved_model, self.pb, self.triton, self.nn_module\n if any(warmup_types) and (self.device.type != \"cpu\" or self.triton):\n im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input\n for _ in range(2 if self.jit else 1):\n self.forward(im) # warmup\n\n @staticmethod\n def _model_type(p: str = \"path/to/model.pt\") -> List[bool]:\n \"\"\"\n Take a path to a model file and return the model type.\n\n Args:\n p (str): Path to the model file.\n\n Returns:\n (List[bool]): List of booleans indicating the model type.\n\n Examples:\n >>> model = AutoBackend(weights=\"path/to/model.onnx\")\n >>> model_type = model._model_type() # returns \"onnx\"\n \"\"\"\n from ultralytics.engine.exporter import export_formats\n\n sf = export_formats()[\"Suffix\"] # export suffixes\n if not is_url(p) and not isinstance(p, str):\n check_suffix(p, sf) # checks\n name = Path(p).name\n types = [s in name for s in sf]\n types[5] |= name.endswith(\".mlmodel\") # retain support for older Apple CoreML *.mlmodel formats\n types[8] &= not types[9] # tflite &= not edgetpu\n if any(types):\n triton = False\n else:\n from urllib.parse import urlsplit\n\n url = urlsplit(p)\n triton = bool(url.netloc) and bool(url.path) and url.scheme in {\"http\", \"grpc\"}\n\n return types + [triton]", "chunk_type": "class", "name": "AutoBackend", "file_path": "ultralytics\\ultralytics\\nn\\autobackend.py", "start_line": 70, "end_line": 895, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "Handle dynamic backend selection for running inference using Ultralytics YOLO models.\n\nThe AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide\nrange of formats, each with specific naming conventions as outlined below:\n\n Supported Formats and Naming Conventions:\n | Format | File Suffix |\n | --------------------- | ----------------- |\n | PyTorch | *.pt |\n | TorchScript | *.torchscript |\n | ONNX Runtime | *.onnx |\n | ONNX OpenCV DNN | *.onnx (dnn=True) |\n | OpenVINO | *openvino_model/ |\n | CoreML | *.mlpackage |\n | TensorRT | *.engine |\n | TensorFlow SavedModel | *_saved_model/ |\n | TensorFlow GraphDef | *.pb |\n | TensorFlow Lite | *.tflite |\n | TensorFlow Edge TPU | *_edgetpu.tflite |\n | PaddlePaddle | *_paddle_model/ |\n | MNN | *.mnn |\n | NCNN | *_ncnn_model/ |\n | IMX | *_imx_model/ |\n | RKNN | *_rknn_model/ |\n\nAttributes:\n model (torch.nn.Module): The loaded YOLO model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n task (str): The type of task the model performs (detect, segment, classify, pose).\n names (dict): A dictionary of class names that the model can detect.\n stride (int): The model stride, typically 32 for YOLO models.\n fp16 (bool): Whether the model uses half-precision (FP16) inference.\n nhwc (bool): Whether the model expects NHWC input format instead of NCHW.\n pt (bool): Whether the model is a PyTorch model.\n jit (bool): Whether the model is a TorchScript model.\n onnx (bool): Whether the model is an ONNX model.\n xml (bool): Whether the model is an OpenVINO model.\n engine (bool): Whether the model is a TensorRT engine.\n coreml (bool): Whether the model is a CoreML model.\n saved_model (bool): Whether the model is a TensorFlow SavedModel.\n pb (bool): Whether the model is a TensorFlow GraphDef.\n tflite (bool): Whether the model is a TensorFlow Lite model.\n edgetpu (bool): Whether the model is a TensorFlow Edge TPU model.\n tfjs (bool): Whether the model 
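A hedged sketch of the suffix/URL logic behind `_model_type`: file names are matched against export suffixes, and anything that parses as an http/grpc URL is treated as a Triton endpoint. The suffix list below is a small illustrative subset, not the full `export_formats()` table.

```python
# Suffix matching and Triton URL detection (illustrative subset of suffixes).
from pathlib import Path
from urllib.parse import urlsplit

suffixes = [".pt", ".torchscript", ".onnx", "_openvino_model", ".engine", ".tflite"]

def classify(p: str):
    types = [s in Path(p).name for s in suffixes]
    if any(types):
        return types + [False]
    url = urlsplit(p)
    triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}
    return types + [triton]

print(classify("yolo11n.onnx"))                   # ONNX flag set
print(classify("http://localhost:8000/yolo11n"))  # Triton flag set
```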
is a TensorFlow.js model.\n paddle (bool): Whether the model is a PaddlePaddle model.\n mnn (bool): Whether the model is an MNN model.\n ncnn (bool): Whether the model is an NCNN model.\n imx (bool): Whether the model is an IMX model.\n rknn (bool): Whether the model is an RKNN model.\n triton (bool): Whether the model is a Triton Inference Server model.\n\nMethods:\n forward: Run inference on an input image.\n from_numpy: Convert numpy array to tensor.\n warmup: Warm up the model with a dummy input.\n _model_type: Determine the model type from file path.\n\nExamples:\n >>> model = AutoBackend(weights=\"yolo11n.pt\", device=\"cuda\")\n >>> results = model(img)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ast", "json", "platform", "zipfile", "collections.OrderedDict", "collections.namedtuple", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "cv2", "numpy", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.ARM64", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.ROOT", "ultralytics.utils.YAML", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.checks.is_rockchip", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.downloads.is_url", "torchvision", "ultralytics.engine.exporter.export_formats", "urllib.parse.urlsplit", "ultralytics.nn.tasks.attempt_load_weights", "torchvision", "onnxruntime", "mct_quantizers", "sony_custom_layers.pytorch.nms.nms_ort", "openvino", "tensorrt", "coremltools", "tensorrt", "tensorflow", "tensorflow", "ultralytics.engine.exporter.gd_outputs", "tflite_runtime.interpreter.Interpreter", "tflite_runtime.interpreter.load_delegate", "tensorflow", "paddle.inference", "os", "MNN", "ncnn", "ultralytics.utils.triton.TritonRemoteModel", "rknnlite.api.RKNNLite", "ultralytics.engine.exporter.export_formats", "nn.Module" ], "chunk_id": "class_AutoBackend_810dcaf6" }, { "content": "import contextlib", "chunk_type": "import", "name": "contextlib", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextlib_47a08d0c" }, { "content": "import pickle", "chunk_type": "import", "name": "pickle", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_pickle_8fa91230" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_032c8be8" }, { "content": "import types", "chunk_type": "import", "name": "types", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, 
"parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_types_e46ed727" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_b572d1ff" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_b73a5b3d" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_51dac660" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_2c981744" }, { "content": "from ultralytics.nn.autobackend import check_class_names", "chunk_type": "import", "name": "check_class_names", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_class_names_363ea4d7" }, { "content": "from ultralytics.nn.modules import (\n AIFI,\n C1,\n C2,\n C2PSA,\n C3,\n C3TR,\n ELAN1,\n OBB,\n PSA,\n SPP,\n SPPELAN,\n SPPF,\n A2C2f,\n AConv,\n ADown,\n Bottleneck,\n BottleneckCSP,\n C2f,\n C2fAttn,\n C2fCIB,\n C2fPSA,\n C3Ghost,\n C3k2,\n C3x,\n CBFuse,\n CBLinear,\n Classify,\n Concat,\n Conv,\n Conv2,\n ConvTranspose,\n Detect,\n DWConv,\n DWConvTranspose2d,\n Focus,\n GhostBottleneck,\n GhostConv,\n HGBlock,\n HGStem,\n ImagePoolingAttn,\n Index,\n LRPCHead,\n Pose,\n RepC3,\n RepConv,\n RepNCSPELAN4,\n RepVGGDW,\n ResNetLayer,\n RTDETRDecoder,\n SCDown,\n Segment,\n TorchVision,\n WorldDetect,\n YOLOEDetect,\n YOLOESegment,\n v10Detect,\n)", "chunk_type": "import", "name": "AIFI, C1, C2, C2PSA, C3, C3TR, ELAN1, OBB, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, ImagePoolingAttn, Index, LRPCHead, Pose, RepC3, RepConv, RepNCSPELAN4, RepVGGDW, ResNetLayer, RTDETRDecoder, SCDown, Segment, TorchVision, WorldDetect, YOLOEDetect, YOLOESegment, v10Detect", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 14, "end_line": 71, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, 
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AIFI, C1, C2, C2PSA, C3, C3TR, ELAN1, OBB, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, ImagePoolingAttn, Index, LRPCHead, Pose, RepC3, RepConv, RepNCSPELAN4, RepVGGDW, ResNetLayer, RTDETRDecoder, SCDown, Segment, TorchVision, WorldDetect, YOLOEDetect, YOLOESegment, v10Detect_8af932ab" }, { "content": "from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, YAML, colorstr, emojis", "chunk_type": "import", "name": "DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, YAML, colorstr, emojis", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 72, "end_line": 72, "start_col": 0, "end_col": 96, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, YAML, colorstr, emojis_bbe07ad7" }, { "content": "from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml", "chunk_type": "import", "name": "check_requirements, check_suffix, check_yaml", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 73, "end_line": 73, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements, check_suffix, check_yaml_a410099c" }, { "content": "from ultralytics.utils.loss import (\n E2EDetectLoss,\n v8ClassificationLoss,\n v8DetectionLoss,\n v8OBBLoss,\n v8PoseLoss,\n v8SegmentationLoss,\n)", "chunk_type": "import", "name": "E2EDetectLoss, v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 74, "end_line": 81, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_E2EDetectLoss, v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss_31b4689a" }, { "content": "from ultralytics.utils.ops import make_divisible", "chunk_type": "import", "name": "make_divisible", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 82, "end_line": 82, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_make_divisible_d5641fbb" }, { "content": "from ultralytics.utils.patches import torch_load", "chunk_type": "import", "name": "torch_load", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 83, "end_line": 83, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_load_c530804d" }, { "content": "from ultralytics.utils.plotting import feature_visualization", "chunk_type": "import", "name": "feature_visualization", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 84, "end_line": 84, "start_col": 0, "end_col": 60, "parent_name": 
null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_feature_visualization_286a415a" }, { "content": "from ultralytics.utils.torch_utils import (\n fuse_conv_and_bn,\n fuse_deconv_and_bn,\n initialize_weights,\n intersect_dicts,\n model_info,\n scale_img,\n smart_inference_mode,\n time_sync,\n)", "chunk_type": "import", "name": "fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, model_info, scale_img, smart_inference_mode, time_sync", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 85, "end_line": 94, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, model_info, scale_img, smart_inference_mode, time_sync_9f7758dc" }, { "content": "class BaseModel(torch.nn.Module):\n \"\"\"\n Base class for all YOLO models in the Ultralytics family.\n\n This class provides common functionality for YOLO models including forward pass handling, model fusion,\n information display, and weight loading capabilities.\n\n Attributes:\n model (torch.nn.Module): The neural network model.\n save (list): List of layer indices to save outputs from.\n stride (torch.Tensor): Model stride values.\n\n Methods:\n forward: Perform forward pass for training or inference.\n predict: Perform inference on input tensor.\n fuse: Fuse Conv2d and BatchNorm2d layers for optimization.\n info: Print model information.\n load: Load weights into the model.\n loss: Compute loss for training.\n\n Examples:\n Create a BaseModel instance\n >>> model = BaseModel()\n >>> model.info() # Display model information\n \"\"\"\n\n def forward(self, x, *args, **kwargs):\n \"\"\"\n Perform forward pass of the model for either training or inference.\n\n If x is a dict, calculates and returns the loss for training. 
Otherwise, returns predictions for inference.\n\n Args:\n x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.\n *args (Any): Variable length argument list.\n **kwargs (Any): Arbitrary keyword arguments.\n\n Returns:\n (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).\n \"\"\"\n if isinstance(x, dict): # for cases of training and validating while training.\n return self.loss(x, *args, **kwargs)\n return self.predict(x, *args, **kwargs)\n\n def predict(self, x, profile=False, visualize=False, augment=False, embed=None):\n \"\"\"\n Perform a forward pass through the network.\n\n Args:\n x (torch.Tensor): The input tensor to the model.\n profile (bool): Print the computation time of each layer if True.\n visualize (bool): Save the feature maps of the model if True.\n augment (bool): Augment image during prediction.\n embed (list, optional): A list of feature vectors/embeddings to return.\n\n Returns:\n (torch.Tensor): The last output of the model.\n \"\"\"\n if augment:\n return self._predict_augment(x)\n return self._predict_once(x, profile, visualize, embed)\n\n def _predict_once(self, x, profile=False, visualize=False, embed=None):\n \"\"\"\n Perform a forward pass through the network.\n\n Args:\n x (torch.Tensor): The input tensor to the model.\n profile (bool): Print the computation time of each layer if True.\n visualize (bool): Save the feature maps of the model if True.\n embed (list, optional): A list of feature vectors/embeddings to return.\n\n Returns:\n (torch.Tensor): The last output of the model.\n \"\"\"\n y, dt, embeddings = [], [], [] # outputs\n embed = frozenset(embed) if embed is not None else {-1}\n max_idx = max(embed)\n for m in self.model:\n if m.f != -1: # if not from previous layer\n x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers\n if profile:\n self._profile_one_layer(m, x, dt)\n x = m(x) # run\n y.append(x if m.i in self.save else None) # save output\n if visualize:\n feature_visualization(x, m.type, m.i, save_dir=visualize)\n if m.i in embed:\n embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten\n if m.i == max_idx:\n return torch.unbind(torch.cat(embeddings, 1), dim=0)\n return x\n\n def _predict_augment(self, x):\n \"\"\"Perform augmentations on input image x and return augmented inference.\"\"\"\n LOGGER.warning(\n f\"{self.__class__.__name__} does not support 'augment=True' prediction. 
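A toy sketch of the "from"-index routing that `_predict_once` performs: each layer records where its input comes from (`f`) and whether its output must be kept for a later skip connection. The modules and indices below are invented for illustration, not produced by the real model parser.

```python
# Toy routing loop in the style of _predict_once (invented layers and indices).
import torch
import torch.nn as nn

class Concat(nn.Module):
    def forward(self, xs):
        return torch.cat(xs, dim=1)

layers = [
    (nn.Conv2d(3, 8, 3, padding=1), -1, 0),  # (module, from-index f, own index i)
    (nn.Conv2d(8, 8, 3, padding=1), -1, 1),
    (Concat(), [-1, 0], 2),                  # skip connection back to layer 0
]
save = {0}                                   # outputs later layers still reference

x, y = torch.zeros(1, 3, 32, 32), []
for m, f, i in layers:
    if f != -1:                              # gather inputs from earlier layers
        x = y[f] if isinstance(f, int) else [x if j == -1 else y[j] for j in f]
    x = m(x)
    y.append(x if i in save else None)       # keep only what is still needed
print(x.shape)                               # torch.Size([1, 16, 32, 32])
```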
\"\n f\"Reverting to single-scale prediction.\"\n )\n return self._predict_once(x)\n\n def _profile_one_layer(self, m, x, dt):\n \"\"\"\n Profile the computation time and FLOPs of a single layer of the model on a given input.\n\n Args:\n m (torch.nn.Module): The layer to be profiled.\n x (torch.Tensor): The input data to the layer.\n dt (list): A list to store the computation time of the layer.\n \"\"\"\n try:\n import thop\n except ImportError:\n thop = None # conda support without 'ultralytics-thop' installed\n\n c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix\n flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs\n t = time_sync()\n for _ in range(10):\n m(x.copy() if c else x)\n dt.append((time_sync() - t) * 100)\n if m == self.model[0]:\n LOGGER.info(f\"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module\")\n LOGGER.info(f\"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}\")\n if c:\n LOGGER.info(f\"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total\")\n\n def fuse(self, verbose=True):\n \"\"\"\n Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation\n efficiency.\n\n Returns:\n (torch.nn.Module): The fused model is returned.\n \"\"\"\n if not self.is_fused():\n for m in self.model.modules():\n if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, \"bn\"):\n if isinstance(m, Conv2):\n m.fuse_convs()\n m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv\n delattr(m, \"bn\") # remove batchnorm\n m.forward = m.forward_fuse # update forward\n if isinstance(m, ConvTranspose) and hasattr(m, \"bn\"):\n m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)\n delattr(m, \"bn\") # remove batchnorm\n m.forward = m.forward_fuse # update forward\n if isinstance(m, RepConv):\n m.fuse_convs()\n m.forward = m.forward_fuse # update forward\n if isinstance(m, RepVGGDW):\n m.fuse()\n m.forward = m.forward_fuse\n if isinstance(m, v10Detect):\n m.fuse() # remove one2many head\n if isinstance(m, YOLOEDetect) and hasattr(self, \"pe\"):\n m.fuse(self.pe.to(next(self.model.parameters()).device))\n self.info(verbose=verbose)\n\n return self\n\n def is_fused(self, thresh=10):\n \"\"\"\n Check if the model has less than a certain threshold of BatchNorm layers.\n\n Args:\n thresh (int, optional): The threshold number of BatchNorm layers.\n\n Returns:\n (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.\n \"\"\"\n bn = tuple(v for k, v in torch.nn.__dict__.items() if \"Norm\" in k) # normalization layers, i.e. 
BatchNorm2d()\n return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model\n\n def info(self, detailed=False, verbose=True, imgsz=640):\n \"\"\"\n Print model information.\n\n Args:\n detailed (bool): If True, prints out detailed information about the model.\n verbose (bool): If True, prints out the model information.\n imgsz (int): The size of the image that the model will be trained on.\n \"\"\"\n return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)\n\n def _apply(self, fn):\n \"\"\"\n Apply a function to all tensors in the model that are not parameters or registered buffers.\n\n Args:\n fn (function): The function to apply to the model.\n\n Returns:\n (BaseModel): An updated BaseModel object.\n \"\"\"\n self = super()._apply(fn)\n m = self.model[-1] # Detect()\n if isinstance(\n m, Detect\n ): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect, YOLOESegment\n m.stride = fn(m.stride)\n m.anchors = fn(m.anchors)\n m.strides = fn(m.strides)\n return self\n\n def load(self, weights, verbose=True):\n \"\"\"\n Load weights into the model.\n\n Args:\n weights (dict | torch.nn.Module): The pre-trained weights to be loaded.\n verbose (bool, optional): Whether to log the transfer progress.\n \"\"\"\n model = weights[\"model\"] if isinstance(weights, dict) else weights # torchvision models are not dicts\n csd = model.float().state_dict() # checkpoint state_dict as FP32\n updated_csd = intersect_dicts(csd, self.state_dict()) # intersect\n self.load_state_dict(updated_csd, strict=False) # load\n len_updated_csd = len(updated_csd)\n first_conv = \"model.0.conv.weight\" # hard-coded to yolo models for now\n # mostly used to boost multi-channel training\n state_dict = self.state_dict()\n if first_conv not in updated_csd and first_conv in state_dict:\n c1, c2, h, w = state_dict[first_conv].shape\n cc1, cc2, ch, cw = csd[first_conv].shape\n if ch == h and cw == w:\n c1, c2 = min(c1, cc1), min(c2, cc2)\n state_dict[first_conv][:c1, :c2] = csd[first_conv][:c1, :c2]\n len_updated_csd += 1\n if verbose:\n LOGGER.info(f\"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights\")\n\n def loss(self, batch, preds=None):\n \"\"\"\n Compute loss.\n\n Args:\n batch (dict): Batch to compute loss on.\n preds (torch.Tensor | List[torch.Tensor], optional): Predictions.\n \"\"\"\n if getattr(self, \"criterion\", None) is None:\n self.criterion = self.init_criterion()\n\n preds = self.forward(batch[\"img\"]) if preds is None else preds\n return self.criterion(preds, batch)\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the BaseModel.\"\"\"\n raise NotImplementedError(\"compute_loss() needs to be implemented by task heads\")", "chunk_type": "class", "name": "BaseModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 97, "end_line": 344, "start_col": 0, "end_col": 89, "parent_name": null, "docstring": "Base class for all YOLO models in the Ultralytics family.\n\nThis class provides common functionality for YOLO models including forward pass handling, model fusion,\ninformation display, and weight loading capabilities.\n\nAttributes:\n model (torch.nn.Module): The neural network model.\n save (list): List of layer indices to save outputs from.\n stride (torch.Tensor): Model stride values.\n\nMethods:\n forward: Perform forward pass for training or inference.\n predict: Perform inference on input tensor.\n fuse: Fuse Conv2d and BatchNorm2d 
layers for optimization.\n info: Print model information.\n load: Load weights into the model.\n loss: Compute loss for training.\n\nExamples:\n Create a BaseModel instance\n >>> model = BaseModel()\n >>> model.info() # Display model information", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", 
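The `forward()` dispatch in the BaseModel chunk above routes dict inputs (training/validation batches) to the loss path and plain tensors to inference. A minimal, dependency-free sketch of that pattern; `TinyModel` is a hypothetical stand-in, not the real BaseModel:

```python
import torch


class TinyModel(torch.nn.Module):  # hypothetical stand-in for BaseModel's dispatch logic
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x, *args, **kwargs):
        if isinstance(x, dict):  # training/validation batch -> loss path
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)  # raw tensor -> inference path

    def predict(self, x, *args, **kwargs):
        return self.net(x)

    def loss(self, batch, *args, **kwargs):
        preds = self.predict(batch["img"])
        return torch.nn.functional.mse_loss(preds, torch.zeros_like(preds))


m = TinyModel()
print(m(torch.rand(1, 3, 32, 32)).shape)     # inference path -> torch.Size([1, 8, 32, 32])
print(m({"img": torch.rand(1, 3, 32, 32)}))  # training path -> scalar loss tensor
```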
"ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "torch.nn.Module" ], "chunk_id": "class_BaseModel_ef2d153f" }, { "content": "class DetectionModel(BaseModel):\n \"\"\"\n YOLO detection model.\n\n This class implements the YOLO detection architecture, handling model initialization, forward pass,\n augmented inference, and loss computation for object detection tasks.\n\n Attributes:\n yaml (dict): Model configuration dictionary.\n model (torch.nn.Sequential): The neural network model.\n save (list): List of layer indices to save outputs from.\n names (dict): Class names dictionary.\n inplace (bool): Whether to use inplace operations.\n end2end (bool): Whether the model uses end-to-end detection.\n stride (torch.Tensor): Model stride values.\n\n Methods:\n __init__: Initialize the YOLO detection model.\n _predict_augment: Perform augmented inference.\n _descale_pred: De-scale predictions following augmented inference.\n _clip_augmented: Clip YOLO augmented inference tails.\n init_criterion: Initialize the loss criterion.\n\n Examples:\n Initialize a detection model\n >>> model = DetectionModel(\"yolo11n.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"yolo11n.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize the YOLO detection model with the given config and parameters.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n super().__init__()\n self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict\n if self.yaml[\"backbone\"][0][2] == \"Silence\":\n LOGGER.warning(\n \"YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. 
\"\n \"Please delete local *.pt file and re-download the latest model checkpoint.\"\n )\n self.yaml[\"backbone\"][0][2] = \"nn.Identity\"\n\n # Define model\n self.yaml[\"channels\"] = ch # save channels\n if nc and nc != self.yaml[\"nc\"]:\n LOGGER.info(f\"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}\")\n self.yaml[\"nc\"] = nc # override YAML value\n self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist\n self.names = {i: f\"{i}\" for i in range(self.yaml[\"nc\"])} # default names dict\n self.inplace = self.yaml.get(\"inplace\", True)\n self.end2end = getattr(self.model[-1], \"end2end\", False)\n\n # Build strides\n m = self.model[-1] # Detect()\n if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, YOLOEDetect, YOLOESegment\n s = 256 # 2x min stride\n m.inplace = self.inplace\n\n def _forward(x):\n \"\"\"Perform a forward pass through the model, handling different Detect subclass types accordingly.\"\"\"\n if self.end2end:\n return self.forward(x)[\"one2many\"]\n return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)\n\n self.model.eval() # Avoid changing batch statistics until training begins\n m.training = True # Setting it to True to properly return strides\n m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward\n self.stride = m.stride\n self.model.train() # Set model back to training(default) mode\n m.bias_init() # only run once\n else:\n self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR\n\n # Init weights, biases\n initialize_weights(self)\n if verbose:\n self.info()\n LOGGER.info(\"\")\n\n def _predict_augment(self, x):\n \"\"\"\n Perform augmentations on input image x and return augmented inference and train outputs.\n\n Args:\n x (torch.Tensor): Input image tensor.\n\n Returns:\n (torch.Tensor): Augmented inference output.\n \"\"\"\n if getattr(self, \"end2end\", False) or self.__class__.__name__ != \"DetectionModel\":\n LOGGER.warning(\"Model does not support 'augment=True', reverting to single-scale prediction.\")\n return self._predict_once(x)\n img_size = x.shape[-2:] # height, width\n s = [1, 0.83, 0.67] # scales\n f = [None, 3, None] # flips (2-ud, 3-lr)\n y = [] # outputs\n for si, fi in zip(s, f):\n xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))\n yi = super().predict(xi)[0] # forward\n yi = self._descale_pred(yi, fi, si, img_size)\n y.append(yi)\n y = self._clip_augmented(y) # clip augmented tails\n return torch.cat(y, -1), None # augmented inference, train\n\n @staticmethod\n def _descale_pred(p, flips, scale, img_size, dim=1):\n \"\"\"\n De-scale predictions following augmented inference (inverse operation).\n\n Args:\n p (torch.Tensor): Predictions tensor.\n flips (int): Flip type (0=none, 2=ud, 3=lr).\n scale (float): Scale factor.\n img_size (tuple): Original image size (height, width).\n dim (int): Dimension to split at.\n\n Returns:\n (torch.Tensor): De-scaled predictions.\n \"\"\"\n p[:, :4] /= scale # de-scale\n x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)\n if flips == 2:\n y = img_size[0] - y # de-flip ud\n elif flips == 3:\n x = img_size[1] - x # de-flip lr\n return torch.cat((x, y, wh, cls), dim)\n\n def _clip_augmented(self, y):\n \"\"\"\n Clip YOLO augmented inference tails.\n\n Args:\n y (List[torch.Tensor]): List of detection tensors.\n\n Returns:\n (List[torch.Tensor]): Clipped detection tensors.\n \"\"\"\n nl = 
self.model[-1].nl # number of detection layers (P3-P5)\n g = sum(4**x for x in range(nl)) # grid points\n e = 1 # exclude layer count\n i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices\n y[0] = y[0][..., :-i] # large\n i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices\n y[-1] = y[-1][..., i:] # small\n return y\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the DetectionModel.\"\"\"\n return E2EDetectLoss(self) if getattr(self, \"end2end\", False) else v8DetectionLoss(self)", "chunk_type": "class", "name": "DetectionModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 347, "end_line": 501, "start_col": 0, "end_col": 96, "parent_name": null, "docstring": "YOLO detection model.\n\nThis class implements the YOLO detection architecture, handling model initialization, forward pass,\naugmented inference, and loss computation for object detection tasks.\n\nAttributes:\n yaml (dict): Model configuration dictionary.\n model (torch.nn.Sequential): The neural network model.\n save (list): List of layer indices to save outputs from.\n names (dict): Class names dictionary.\n inplace (bool): Whether to use inplace operations.\n end2end (bool): Whether the model uses end-to-end detection.\n stride (torch.Tensor): Model stride values.\n\nMethods:\n __init__: Initialize the YOLO detection model.\n _predict_augment: Perform augmented inference.\n _descale_pred: De-scale predictions following augmented inference.\n _clip_augmented: Clip YOLO augmented inference tails.\n init_criterion: Initialize the loss criterion.\n\nExamples:\n Initialize a detection model\n >>> model = DetectionModel(\"yolo11n.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", 
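The stride-building step in `DetectionModel.__init__` above derives per-head strides by pushing a 256x256 dummy tensor through the model and dividing by each output's height. A small numeric sketch under assumed P3/P4/P5 feature-map sizes:

```python
import torch

s = 256  # dummy input size, as in the code above
feature_heights = [32, 16, 8]  # assumed output heights for a stride-8/16/32 detector
strides = torch.tensor([s / h for h in feature_heights])
print(strides)  # tensor([ 8., 16., 32.])
```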
"ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "BaseModel" ], "chunk_id": "class_DetectionModel_18799f43" }, { "content": "class OBBModel(DetectionModel):\n \"\"\"\n YOLO Oriented Bounding Box (OBB) model.\n\n This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized\n loss computation for rotated object detection.\n\n Methods:\n __init__: Initialize YOLO OBB model.\n init_criterion: Initialize the loss criterion for OBB detection.\n\n Examples:\n Initialize an OBB model\n >>> model = OBBModel(\"yolo11n-obb.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"yolo11n-obb.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize YOLO OBB model with given config and parameters.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the model.\"\"\"\n return v8OBBLoss(self)", "chunk_type": "class", "name": "OBBModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 504, "end_line": 535, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "YOLO Oriented Bounding Box (OBB) model.\n\nThis class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized\nloss computation for rotated object detection.\n\nMethods:\n __init__: Initialize YOLO OBB model.\n init_criterion: Initialize the loss criterion for OBB detection.\n\nExamples:\n Initialize an OBB model\n >>> model = OBBModel(\"yolo11n-obb.yaml\", ch=3, nc=80)\n >>> results = 
model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", 
"ultralytics.utils.loss.TVPSegmentLoss", "DetectionModel" ], "chunk_id": "class_OBBModel_2e6268cb" }, { "content": "class SegmentationModel(DetectionModel):\n \"\"\"\n YOLO segmentation model.\n\n This class extends DetectionModel to handle instance segmentation tasks, providing specialized\n loss computation for pixel-level object detection and segmentation.\n\n Methods:\n __init__: Initialize YOLO segmentation model.\n init_criterion: Initialize the loss criterion for segmentation.\n\n Examples:\n Initialize a segmentation model\n >>> model = SegmentationModel(\"yolo11n-seg.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"yolo11n-seg.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize Ultralytics YOLO segmentation model with given config and parameters.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the SegmentationModel.\"\"\"\n return v8SegmentationLoss(self)", "chunk_type": "class", "name": "SegmentationModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 538, "end_line": 569, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "YOLO segmentation model.\n\nThis class extends DetectionModel to handle instance segmentation tasks, providing specialized\nloss computation for pixel-level object detection and segmentation.\n\nMethods:\n __init__: Initialize YOLO segmentation model.\n init_criterion: Initialize the loss criterion for segmentation.\n\nExamples:\n Initialize a segmentation model\n >>> model = SegmentationModel(\"yolo11n-seg.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", 
"ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "DetectionModel" ], "chunk_id": "class_SegmentationModel_3740523f" }, { "content": "class PoseModel(DetectionModel):\n \"\"\"\n YOLO pose model.\n\n This class extends DetectionModel to handle human pose estimation tasks, providing specialized\n loss computation for keypoint detection and pose estimation.\n\n Attributes:\n kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).\n\n Methods:\n __init__: Initialize YOLO pose model.\n init_criterion: Initialize the loss criterion for pose estimation.\n\n Examples:\n Initialize a pose model\n >>> model = PoseModel(\"yolo11n-pose.yaml\", ch=3, nc=1, data_kpt_shape=(17, 3))\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"yolo11n-pose.yaml\", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):\n \"\"\"\n Initialize Ultralytics YOLO Pose model.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n data_kpt_shape (tuple): Shape of keypoints data.\n verbose (bool): Whether to display model information.\n \"\"\"\n if not isinstance(cfg, dict):\n cfg = yaml_model_load(cfg) # load model YAML\n if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg[\"kpt_shape\"]):\n LOGGER.info(f\"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}\")\n cfg[\"kpt_shape\"] = data_kpt_shape\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the PoseModel.\"\"\"\n return 
v8PoseLoss(self)", "chunk_type": "class", "name": "PoseModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 572, "end_line": 612, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "YOLO pose model.\n\nThis class extends DetectionModel to handle human pose estimation tasks, providing specialized\nloss computation for keypoint detection and pose estimation.\n\nAttributes:\n kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).\n\nMethods:\n __init__: Initialize YOLO pose model.\n init_criterion: Initialize the loss criterion for pose estimation.\n\nExamples:\n Initialize a pose model\n >>> model = PoseModel(\"yolo11n-pose.yaml\", ch=3, nc=1, data_kpt_shape=(17, 3))\n >>> results = model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", 
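`PoseModel.__init__` above overrides the YAML `kpt_shape` whenever a non-empty `data_kpt_shape` disagrees with it; the default `(None, None)` fails the `any()` check and leaves the YAML value untouched. A hedged sketch with illustrative values:

```python
cfg = {"kpt_shape": [17, 3]}   # value read from the model YAML (illustrative)
data_kpt_shape = (17, 2)       # shape supplied by the dataset config (illustrative)

if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
    print(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
    cfg["kpt_shape"] = data_kpt_shape

print(cfg["kpt_shape"])  # (17, 2)
```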
"ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "DetectionModel" ], "chunk_id": "class_PoseModel_2cecb20e" }, { "content": "class ClassificationModel(BaseModel):\n \"\"\"\n YOLO classification model.\n\n This class implements the YOLO classification architecture for image classification tasks,\n providing model initialization, configuration, and output reshaping capabilities.\n\n Attributes:\n yaml (dict): Model configuration dictionary.\n model (torch.nn.Sequential): The neural network model.\n stride (torch.Tensor): Model stride values.\n names (dict): Class names dictionary.\n\n Methods:\n __init__: Initialize ClassificationModel.\n _from_yaml: Set model configurations and define architecture.\n reshape_outputs: Update model to specified class count.\n init_criterion: Initialize the loss criterion.\n\n Examples:\n Initialize a classification model\n >>> model = ClassificationModel(\"yolo11n-cls.yaml\", ch=3, nc=1000)\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"yolo11n-cls.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n super().__init__()\n self._from_yaml(cfg, ch, nc, verbose)\n\n def _from_yaml(self, cfg, ch, nc, verbose):\n \"\"\"\n Set Ultralytics YOLO model configurations and define the model architecture.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict\n\n # Define model\n ch = self.yaml[\"channels\"] = self.yaml.get(\"channels\", ch) # input channels\n if nc and nc != self.yaml[\"nc\"]:\n LOGGER.info(f\"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}\")\n self.yaml[\"nc\"] = nc # override YAML value\n elif not nc and not self.yaml.get(\"nc\", None):\n raise ValueError(\"nc not specified. 
Must specify nc in model.yaml or function arguments.\")\n self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist\n self.stride = torch.Tensor([1]) # no stride constraints\n self.names = {i: f\"{i}\" for i in range(self.yaml[\"nc\"])} # default names dict\n self.info()\n\n @staticmethod\n def reshape_outputs(model, nc):\n \"\"\"\n Update a TorchVision classification model to class count 'n' if required.\n\n Args:\n model (torch.nn.Module): Model to update.\n nc (int): New number of classes.\n \"\"\"\n name, m = list((model.model if hasattr(model, \"model\") else model).named_children())[-1] # last module\n if isinstance(m, Classify): # YOLO Classify() head\n if m.linear.out_features != nc:\n m.linear = torch.nn.Linear(m.linear.in_features, nc)\n elif isinstance(m, torch.nn.Linear): # ResNet, EfficientNet\n if m.out_features != nc:\n setattr(model, name, torch.nn.Linear(m.in_features, nc))\n elif isinstance(m, torch.nn.Sequential):\n types = [type(x) for x in m]\n if torch.nn.Linear in types:\n i = len(types) - 1 - types[::-1].index(torch.nn.Linear) # last torch.nn.Linear index\n if m[i].out_features != nc:\n m[i] = torch.nn.Linear(m[i].in_features, nc)\n elif torch.nn.Conv2d in types:\n i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d) # last torch.nn.Conv2d index\n if m[i].out_channels != nc:\n m[i] = torch.nn.Conv2d(\n m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None\n )\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the ClassificationModel.\"\"\"\n return v8ClassificationLoss()", "chunk_type": "class", "name": "ClassificationModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 615, "end_line": 708, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "YOLO classification model.\n\nThis class implements the YOLO classification architecture for image classification tasks,\nproviding model initialization, configuration, and output reshaping capabilities.\n\nAttributes:\n yaml (dict): Model configuration dictionary.\n model (torch.nn.Sequential): The neural network model.\n stride (torch.Tensor): Model stride values.\n names (dict): Class names dictionary.\n\nMethods:\n __init__: Initialize ClassificationModel.\n _from_yaml: Set model configurations and define architecture.\n reshape_outputs: Update model to specified class count.\n init_criterion: Initialize the loss criterion.\n\nExamples:\n Initialize a classification model\n >>> model = ClassificationModel(\"yolo11n-cls.yaml\", ch=3, nc=1000)\n >>> results = model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", 
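`reshape_outputs()` above patches the final Linear (or Conv2d) layer of a classification head so its output size matches `nc`. A self-contained sketch of the `torch.nn.Sequential` branch with a toy 1000-class head:

```python
import torch

head = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(512, 1000))  # toy head
nc = 10

types = [type(x) for x in head]
if torch.nn.Linear in types:
    i = len(types) - 1 - types[::-1].index(torch.nn.Linear)  # index of the last Linear
    if head[i].out_features != nc:
        head[i] = torch.nn.Linear(head[i].in_features, nc)

print(head[i])  # Linear(in_features=512, out_features=10, bias=True)
```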
"ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "BaseModel" ], "chunk_id": "class_ClassificationModel_59fcf047" }, { "content": "class RTDETRDetectionModel(DetectionModel):\n \"\"\"\n RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.\n\n This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both\n the training and inference processes. 
RTDETR is an object detection and tracking model that extends from the\n DetectionModel base class.\n\n Attributes:\n nc (int): Number of classes for detection.\n criterion (RTDETRDetectionLoss): Loss function for training.\n\n Methods:\n __init__: Initialize the RTDETRDetectionModel.\n init_criterion: Initialize the loss criterion.\n loss: Compute loss for training.\n predict: Perform forward pass through the model.\n\n Examples:\n Initialize an RTDETR model\n >>> model = RTDETRDetectionModel(\"rtdetr-l.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"rtdetr-l.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize the RTDETRDetectionModel.\n\n Args:\n cfg (str | dict): Configuration file name or path.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Print additional information during initialization.\n \"\"\"\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n def init_criterion(self):\n \"\"\"Initialize the loss criterion for the RTDETRDetectionModel.\"\"\"\n from ultralytics.models.utils.loss import RTDETRDetectionLoss\n\n return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)\n\n def loss(self, batch, preds=None):\n \"\"\"\n Compute the loss for the given batch of data.\n\n Args:\n batch (dict): Dictionary containing image and label data.\n preds (torch.Tensor, optional): Precomputed model predictions.\n\n Returns:\n loss_sum (torch.Tensor): Total loss value.\n loss_items (torch.Tensor): Main three losses in a tensor.\n \"\"\"\n if not hasattr(self, \"criterion\"):\n self.criterion = self.init_criterion()\n\n img = batch[\"img\"]\n # NOTE: preprocess gt_bbox and gt_labels to list.\n bs = len(img)\n batch_idx = batch[\"batch_idx\"]\n gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]\n targets = {\n \"cls\": batch[\"cls\"].to(img.device, dtype=torch.long).view(-1),\n \"bboxes\": batch[\"bboxes\"].to(device=img.device),\n \"batch_idx\": batch_idx.to(img.device, dtype=torch.long).view(-1),\n \"gt_groups\": gt_groups,\n }\n\n preds = self.predict(img, batch=targets) if preds is None else preds\n dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]\n if dn_meta is None:\n dn_bboxes, dn_scores = None, None\n else:\n dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta[\"dn_num_split\"], dim=2)\n dn_scores, dec_scores = torch.split(dec_scores, dn_meta[\"dn_num_split\"], dim=2)\n\n dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)\n dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])\n\n loss = self.criterion(\n (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta\n )\n # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.\n return sum(loss.values()), torch.as_tensor(\n [loss[k].detach() for k in [\"loss_giou\", \"loss_class\", \"loss_bbox\"]], device=img.device\n )\n\n def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):\n \"\"\"\n Perform a forward pass through the model.\n\n Args:\n x (torch.Tensor): The input tensor.\n profile (bool): If True, profile the computation time for each layer.\n visualize (bool): If True, save feature maps for visualization.\n batch (dict, optional): Ground truth data for evaluation.\n augment (bool): If True, perform data augmentation during inference.\n embed (list, optional): A list of feature vectors/embeddings to 
return.\n\n Returns:\n (torch.Tensor): Model's output tensor.\n \"\"\"\n y, dt, embeddings = [], [], [] # outputs\n embed = frozenset(embed) if embed is not None else {-1}\n max_idx = max(embed)\n for m in self.model[:-1]: # except the head part\n if m.f != -1: # if not from previous layer\n x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers\n if profile:\n self._profile_one_layer(m, x, dt)\n x = m(x) # run\n y.append(x if m.i in self.save else None) # save output\n if visualize:\n feature_visualization(x, m.type, m.i, save_dir=visualize)\n if m.i in embed:\n embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten\n if m.i == max_idx:\n return torch.unbind(torch.cat(embeddings, 1), dim=0)\n head = self.model[-1]\n x = head([y[j] for j in head.f], batch) # head inference\n return x", "chunk_type": "class", "name": "RTDETRDetectionModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 711, "end_line": 832, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.\n\nThis class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both\nthe training and inference processes. RTDETR is an object detection and tracking model that extends from the\nDetectionModel base class.\n\nAttributes:\n nc (int): Number of classes for detection.\n criterion (RTDETRDetectionLoss): Loss function for training.\n\nMethods:\n __init__: Initialize the RTDETRDetectionModel.\n init_criterion: Initialize the loss criterion.\n loss: Compute loss for training.\n predict: Perform forward pass through the model.\n\nExamples:\n Initialize an RTDETR model\n >>> model = RTDETRDetectionModel(\"rtdetr-l.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", 
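In `RTDETRDetectionModel.loss()` above, `gt_groups` records how many ground-truth boxes each image in the batch owns, so the flattened cls/bboxes targets can later be split per image by the criterion. A small sketch with made-up box-to-image indices:

```python
import torch

bs = 3
batch_idx = torch.tensor([0, 0, 1, 2, 2, 2])  # each GT box tagged with its image index
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
print(gt_groups)  # [2, 1, 3]
```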
"ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "DetectionModel" ], "chunk_id": "class_RTDETRDetectionModel_3fe7c421" }, { "content": "class WorldModel(DetectionModel):\n \"\"\"\n YOLOv8 World Model.\n\n This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based\n class specification and CLIP model integration for zero-shot detection capabilities.\n\n Attributes:\n txt_feats (torch.Tensor): Text feature embeddings for classes.\n clip_model (torch.nn.Module): CLIP model for text encoding.\n\n Methods:\n __init__: Initialize YOLOv8 world model.\n set_classes: Set classes for offline inference.\n get_text_pe: Get text positional embeddings.\n predict: Perform forward pass with text features.\n loss: Compute loss with text features.\n\n Examples:\n Initialize a world model\n >>> model = WorldModel(\"yolov8s-world.yaml\", ch=3, nc=80)\n >>> model.set_classes([\"person\", \"car\", \"bicycle\"])\n >>> results = model.predict(image_tensor)\n \"\"\"\n\n def __init__(self, cfg=\"yolov8s-world.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize YOLOv8 world model with given config and parameters.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder\n self.clip_model = None # CLIP model placeholder\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n def set_classes(self, text, batch=80, cache_clip_model=True):\n \"\"\"\n Set classes in 
advance so that model could do offline-inference without clip model.\n\n Args:\n text (List[str]): List of class names.\n batch (int): Batch size for processing text tokens.\n cache_clip_model (bool): Whether to cache the CLIP model.\n \"\"\"\n self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)\n self.model[-1].nc = len(text)\n\n def get_text_pe(self, text, batch=80, cache_clip_model=True):\n \"\"\"\n Set classes in advance so that model could do offline-inference without clip model.\n\n Args:\n text (List[str]): List of class names.\n batch (int): Batch size for processing text tokens.\n cache_clip_model (bool): Whether to cache the CLIP model.\n\n Returns:\n (torch.Tensor): Text positional embeddings.\n \"\"\"\n from ultralytics.nn.text_model import build_text_model\n\n device = next(self.model.parameters()).device\n if not getattr(self, \"clip_model\", None) and cache_clip_model:\n # For backwards compatibility of models lacking clip_model attribute\n self.clip_model = build_text_model(\"clip:ViT-B/32\", device=device)\n model = self.clip_model if cache_clip_model else build_text_model(\"clip:ViT-B/32\", device=device)\n text_token = model.tokenize(text)\n txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]\n txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)\n return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])\n\n def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):\n \"\"\"\n Perform a forward pass through the model.\n\n Args:\n x (torch.Tensor): The input tensor.\n profile (bool): If True, profile the computation time for each layer.\n visualize (bool): If True, save feature maps for visualization.\n txt_feats (torch.Tensor, optional): The text features, use it if it's given.\n augment (bool): If True, perform data augmentation during inference.\n embed (list, optional): A list of feature vectors/embeddings to return.\n\n Returns:\n (torch.Tensor): Model's output tensor.\n \"\"\"\n txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)\n if len(txt_feats) != len(x) or self.model[-1].export:\n txt_feats = txt_feats.expand(x.shape[0], -1, -1)\n ori_txt_feats = txt_feats.clone()\n y, dt, embeddings = [], [], [] # outputs\n embed = frozenset(embed) if embed is not None else {-1}\n max_idx = max(embed)\n for m in self.model: # except the head part\n if m.f != -1: # if not from previous layer\n x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers\n if profile:\n self._profile_one_layer(m, x, dt)\n if isinstance(m, C2fAttn):\n x = m(x, txt_feats)\n elif isinstance(m, WorldDetect):\n x = m(x, ori_txt_feats)\n elif isinstance(m, ImagePoolingAttn):\n txt_feats = m(x, txt_feats)\n else:\n x = m(x) # run\n\n y.append(x if m.i in self.save else None) # save output\n if visualize:\n feature_visualization(x, m.type, m.i, save_dir=visualize)\n if m.i in embed:\n embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten\n if m.i == max_idx:\n return torch.unbind(torch.cat(embeddings, 1), dim=0)\n return x\n\n def loss(self, batch, preds=None):\n \"\"\"\n Compute loss.\n\n Args:\n batch (dict): Batch to compute loss on.\n preds (torch.Tensor | List[torch.Tensor], optional): Predictions.\n \"\"\"\n if not hasattr(self, \"criterion\"):\n self.criterion = self.init_criterion()\n\n if preds is None:\n preds = 
self.forward(batch[\"img\"], txt_feats=batch[\"txt_feats\"])\n return self.criterion(preds, batch)", "chunk_type": "class", "name": "WorldModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 835, "end_line": 968, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "YOLOv8 World Model.\n\nThis class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based\nclass specification and CLIP model integration for zero-shot detection capabilities.\n\nAttributes:\n txt_feats (torch.Tensor): Text feature embeddings for classes.\n clip_model (torch.nn.Module): CLIP model for text encoding.\n\nMethods:\n __init__: Initialize YOLOv8 world model.\n set_classes: Set classes for offline inference.\n get_text_pe: Get text positional embeddings.\n predict: Perform forward pass with text features.\n loss: Compute loss with text features.\n\nExamples:\n Initialize a world model\n >>> model = WorldModel(\"yolov8s-world.yaml\", ch=3, nc=80)\n >>> model.set_classes([\"person\", \"car\", \"bicycle\"])\n >>> results = model.predict(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", 
"ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "DetectionModel" ], "chunk_id": "class_WorldModel_a652ec58" }, { "content": "class YOLOEModel(DetectionModel):\n \"\"\"\n YOLOE detection model.\n\n This class implements the YOLOE architecture for efficient object detection with text and visual prompts,\n supporting both prompt-based and prompt-free inference modes.\n\n Attributes:\n pe (torch.Tensor): Prompt embeddings for classes.\n clip_model (torch.nn.Module): CLIP model for text encoding.\n\n Methods:\n __init__: Initialize YOLOE model.\n get_text_pe: Get text positional embeddings.\n get_visual_pe: Get visual embeddings.\n set_vocab: Set vocabulary for prompt-free model.\n get_vocab: Get fused vocabulary layer.\n set_classes: Set classes for offline inference.\n get_cls_pe: Get class positional embeddings.\n predict: Perform forward pass with prompts.\n loss: Compute loss with prompts.\n\n Examples:\n Initialize a YOLOE model\n >>> model = YOLOEModel(\"yoloe-v8s.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor, tpe=text_embeddings)\n \"\"\"\n\n def __init__(self, cfg=\"yoloe-v8s.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize YOLOE model with given config and parameters.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n @smart_inference_mode()\n def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):\n \"\"\"\n Set classes in advance so that model could do offline-inference without clip model.\n\n Args:\n text (List[str]): List of class names.\n batch (int): Batch size for processing text tokens.\n cache_clip_model (bool): Whether to cache the CLIP model.\n without_reprta (bool): Whether to return text embeddings cooperated with reprta module.\n\n Returns:\n (torch.Tensor): Text positional embeddings.\n \"\"\"\n from ultralytics.nn.text_model import build_text_model\n\n device = next(self.model.parameters()).device\n if not getattr(self, \"clip_model\", None) and cache_clip_model:\n # For backwards compatibility of models lacking clip_model attribute\n self.clip_model = build_text_model(\"mobileclip:blt\", device=device)\n\n model = self.clip_model if cache_clip_model else build_text_model(\"mobileclip:blt\", device=device)\n text_token = 
model.tokenize(text)\n txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]\n txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)\n txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])\n if without_reprta:\n return txt_feats\n\n assert not self.training\n head = self.model[-1]\n assert isinstance(head, YOLOEDetect)\n return head.get_tpe(txt_feats) # run auxiliary text head\n\n @smart_inference_mode()\n def get_visual_pe(self, img, visual):\n \"\"\"\n Get visual embeddings.\n\n Args:\n img (torch.Tensor): Input image tensor.\n visual (torch.Tensor): Visual features.\n\n Returns:\n (torch.Tensor): Visual positional embeddings.\n \"\"\"\n return self(img, vpe=visual, return_vpe=True)\n\n def set_vocab(self, vocab, names):\n \"\"\"\n Set vocabulary for the prompt-free model.\n\n Args:\n vocab (nn.ModuleList): List of vocabulary items.\n names (List[str]): List of class names.\n \"\"\"\n assert not self.training\n head = self.model[-1]\n assert isinstance(head, YOLOEDetect)\n\n # Cache anchors for head\n device = next(self.parameters()).device\n self(torch.empty(1, 3, self.args[\"imgsz\"], self.args[\"imgsz\"]).to(device)) # warmup\n\n # re-parameterization for prompt-free model\n self.model[-1].lrpc = nn.ModuleList(\n LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2)\n for i, (cls, pf, loc) in enumerate(zip(vocab, head.cv3, head.cv2))\n )\n for loc_head, cls_head in zip(head.cv2, head.cv3):\n assert isinstance(loc_head, nn.Sequential)\n assert isinstance(cls_head, nn.Sequential)\n del loc_head[-1]\n del cls_head[-1]\n self.model[-1].nc = len(names)\n self.names = check_class_names(names)\n\n def get_vocab(self, names):\n \"\"\"\n Get fused vocabulary layer from the model.\n\n Args:\n names (list): List of class names.\n\n Returns:\n (nn.ModuleList): List of vocabulary modules.\n \"\"\"\n assert not self.training\n head = self.model[-1]\n assert isinstance(head, YOLOEDetect)\n assert not head.is_fused\n\n tpe = self.get_text_pe(names)\n self.set_classes(names, tpe)\n device = next(self.model.parameters()).device\n head.fuse(self.pe.to(device)) # fuse prompt embeddings to classify head\n\n vocab = nn.ModuleList()\n for cls_head in head.cv3:\n assert isinstance(cls_head, nn.Sequential)\n vocab.append(cls_head[-1])\n return vocab\n\n def set_classes(self, names, embeddings):\n \"\"\"\n Set classes in advance so that model could do offline-inference without clip model.\n\n Args:\n names (List[str]): List of class names.\n embeddings (torch.Tensor): Embeddings tensor.\n \"\"\"\n assert not hasattr(self.model[-1], \"lrpc\"), (\n \"Prompt-free model does not support setting classes. 
Please try with Text/Visual prompt models.\"\n )\n assert embeddings.ndim == 3\n self.pe = embeddings\n self.model[-1].nc = len(names)\n self.names = check_class_names(names)\n\n def get_cls_pe(self, tpe, vpe):\n \"\"\"\n Get class positional embeddings.\n\n Args:\n tpe (torch.Tensor, optional): Text positional embeddings.\n vpe (torch.Tensor, optional): Visual positional embeddings.\n\n Returns:\n (torch.Tensor): Class positional embeddings.\n \"\"\"\n all_pe = []\n if tpe is not None:\n assert tpe.ndim == 3\n all_pe.append(tpe)\n if vpe is not None:\n assert vpe.ndim == 3\n all_pe.append(vpe)\n if not all_pe:\n all_pe.append(getattr(self, \"pe\", torch.zeros(1, 80, 512)))\n return torch.cat(all_pe, dim=1)\n\n def predict(\n self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False\n ):\n \"\"\"\n Perform a forward pass through the model.\n\n Args:\n x (torch.Tensor): The input tensor.\n profile (bool): If True, profile the computation time for each layer.\n visualize (bool): If True, save feature maps for visualization.\n tpe (torch.Tensor, optional): Text positional embeddings.\n augment (bool): If True, perform data augmentation during inference.\n embed (list, optional): A list of feature vectors/embeddings to return.\n vpe (torch.Tensor, optional): Visual positional embeddings.\n return_vpe (bool): If True, return visual positional embeddings.\n\n Returns:\n (torch.Tensor): Model's output tensor.\n \"\"\"\n y, dt, embeddings = [], [], [] # outputs\n b = x.shape[0]\n embed = frozenset(embed) if embed is not None else {-1}\n max_idx = max(embed)\n for m in self.model: # except the head part\n if m.f != -1: # if not from previous layer\n x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers\n if profile:\n self._profile_one_layer(m, x, dt)\n if isinstance(m, YOLOEDetect):\n vpe = m.get_vpe(x, vpe) if vpe is not None else None\n if return_vpe:\n assert vpe is not None\n assert not self.training\n return vpe\n cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)\n if cls_pe.shape[0] != b or m.export:\n cls_pe = cls_pe.expand(b, -1, -1)\n x = m(x, cls_pe)\n else:\n x = m(x) # run\n\n y.append(x if m.i in self.save else None) # save output\n if visualize:\n feature_visualization(x, m.type, m.i, save_dir=visualize)\n if m.i in embed:\n embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten\n if m.i == max_idx:\n return torch.unbind(torch.cat(embeddings, 1), dim=0)\n return x\n\n def loss(self, batch, preds=None):\n \"\"\"\n Compute loss.\n\n Args:\n batch (dict): Batch to compute loss on.\n preds (torch.Tensor | List[torch.Tensor], optional): Predictions.\n \"\"\"\n if not hasattr(self, \"criterion\"):\n from ultralytics.utils.loss import TVPDetectLoss\n\n visual_prompt = batch.get(\"visuals\", None) is not None # TODO\n self.criterion = TVPDetectLoss(self) if visual_prompt else self.init_criterion()\n\n if preds is None:\n preds = self.forward(batch[\"img\"], tpe=batch.get(\"txt_feats\", None), vpe=batch.get(\"visuals\", None))\n return self.criterion(preds, batch)", "chunk_type": "class", "name": "YOLOEModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 971, "end_line": 1218, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "YOLOE detection model.\n\nThis class implements the YOLOE architecture for efficient object detection with text and visual prompts,\nsupporting both 
prompt-based and prompt-free inference modes.\n\nAttributes:\n pe (torch.Tensor): Prompt embeddings for classes.\n clip_model (torch.nn.Module): CLIP model for text encoding.\n\nMethods:\n __init__: Initialize YOLOE model.\n get_text_pe: Get text positional embeddings.\n get_visual_pe: Get visual embeddings.\n set_vocab: Set vocabulary for prompt-free model.\n get_vocab: Get fused vocabulary layer.\n set_classes: Set classes for offline inference.\n get_cls_pe: Get class positional embeddings.\n predict: Perform forward pass with prompts.\n loss: Compute loss with prompts.\n\nExamples:\n Initialize a YOLOE model\n >>> model = YOLOEModel(\"yoloe-v8s.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor, tpe=text_embeddings)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", 
"ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "DetectionModel" ], "chunk_id": "class_YOLOEModel_d75e0d90" }, { "content": "class YOLOESegModel(YOLOEModel, SegmentationModel):\n \"\"\"\n YOLOE segmentation model.\n\n This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,\n providing specialized loss computation for pixel-level object detection and segmentation.\n\n Methods:\n __init__: Initialize YOLOE segmentation model.\n loss: Compute loss with prompts for segmentation.\n\n Examples:\n Initialize a YOLOE segmentation model\n >>> model = YOLOESegModel(\"yoloe-v8s-seg.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor, tpe=text_embeddings)\n \"\"\"\n\n def __init__(self, cfg=\"yoloe-v8s-seg.yaml\", ch=3, nc=None, verbose=True):\n \"\"\"\n Initialize YOLOE segmentation model with given config and parameters.\n\n Args:\n cfg (str | dict): Model configuration file path or dictionary.\n ch (int): Number of input channels.\n nc (int, optional): Number of classes.\n verbose (bool): Whether to display model information.\n \"\"\"\n super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)\n\n def loss(self, batch, preds=None):\n \"\"\"\n Compute loss.\n\n Args:\n batch (dict): Batch to compute loss on.\n preds (torch.Tensor | List[torch.Tensor], optional): Predictions.\n \"\"\"\n if not hasattr(self, \"criterion\"):\n from ultralytics.utils.loss import TVPSegmentLoss\n\n visual_prompt = batch.get(\"visuals\", None) is not None # TODO\n self.criterion = TVPSegmentLoss(self) if visual_prompt else self.init_criterion()\n\n if preds is None:\n preds = self.forward(batch[\"img\"], tpe=batch.get(\"txt_feats\", None), vpe=batch.get(\"visuals\", None))\n return self.criterion(preds, batch)", "chunk_type": "class", "name": "YOLOESegModel", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1221, "end_line": 1266, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "YOLOE segmentation model.\n\nThis class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,\nproviding specialized loss computation for pixel-level object detection and segmentation.\n\nMethods:\n __init__: Initialize YOLOE segmentation model.\n loss: Compute loss with prompts for segmentation.\n\nExamples:\n Initialize a YOLOE segmentation model\n >>> model = YOLOESegModel(\"yoloe-v8s-seg.yaml\", ch=3, nc=80)\n >>> results = model.predict(image_tensor, tpe=text_embeddings)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", 
"ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "YOLOEModel", "SegmentationModel" ], "chunk_id": "class_YOLOESegModel_464eb2aa" }, { "content": "class Ensemble(torch.nn.ModuleList):\n \"\"\"\n Ensemble of models.\n\n This class allows combining multiple YOLO models into an ensemble for improved performance through\n model averaging or other ensemble techniques.\n\n 
Methods:\n __init__: Initialize an ensemble of models.\n forward: Generate predictions from all models in the ensemble.\n\n Examples:\n Create an ensemble of models\n >>> ensemble = Ensemble()\n >>> ensemble.append(model1)\n >>> ensemble.append(model2)\n >>> results = ensemble(image_tensor)\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize an ensemble of models.\"\"\"\n super().__init__()\n\n def forward(self, x, augment=False, profile=False, visualize=False):\n \"\"\"\n Generate the YOLO network's final layer.\n\n Args:\n x (torch.Tensor): Input tensor.\n augment (bool): Whether to augment the input.\n profile (bool): Whether to profile the model.\n visualize (bool): Whether to visualize the features.\n\n Returns:\n y (torch.Tensor): Concatenated predictions from all models.\n train_out (None): Always None for ensemble inference.\n \"\"\"\n y = [module(x, augment, profile, visualize)[0] for module in self]\n # y = torch.stack(y).max(0)[0] # max ensemble\n # y = torch.stack(y).mean(0) # mean ensemble\n y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)\n return y, None # inference, train output", "chunk_type": "class", "name": "Ensemble", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1269, "end_line": 1310, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Ensemble of models.\n\nThis class allows combining multiple YOLO models into an ensemble for improved performance through\nmodel averaging or other ensemble techniques.\n\nMethods:\n __init__: Initialize an ensemble of models.\n forward: Generate predictions from all models in the ensemble.\n\nExamples:\n Create an ensemble of models\n >>> ensemble = Ensemble()\n >>> ensemble.append(model1)\n >>> ensemble.append(model2)\n >>> results = ensemble(image_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", 
"ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "torch.nn.ModuleList" ], "chunk_id": "class_Ensemble_cc07f57b" }, { "content": "def temporary_modules(modules=None, attributes=None):\n \"\"\"\n Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).\n\n This function can be used to change the module paths during runtime. It's useful when refactoring code,\n where you've moved a module from one location to another, but you still want to support the old import\n paths for backwards compatibility.\n\n Args:\n modules (dict, optional): A dictionary mapping old module paths to new module paths.\n attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.\n\n Examples:\n >>> with temporary_modules({\"old.module\": \"new.module\"}, {\"old.module.attribute\": \"new.module.attribute\"}):\n >>> import old.module # this will now import new.module\n >>> from old.module import attribute # this will now import new.module.attribute\n\n Note:\n The changes are only in effect inside the context manager and are undone once the context manager exits.\n Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger\n applications or libraries. 
Use this function with caution.\n \"\"\"\n if modules is None:\n modules = {}\n if attributes is None:\n attributes = {}\n import sys\n from importlib import import_module\n\n try:\n # Set attributes in sys.modules under their old name\n for old, new in attributes.items():\n old_module, old_attr = old.rsplit(\".\", 1)\n new_module, new_attr = new.rsplit(\".\", 1)\n setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))\n\n # Set modules in sys.modules under their old name\n for old, new in modules.items():\n sys.modules[old] = import_module(new)\n\n yield\n finally:\n # Remove the temporary module paths\n for old in modules:\n if old in sys.modules:\n del sys.modules[old]", "chunk_type": "function", "name": "temporary_modules", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1317, "end_line": 1362, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).\n\nThis function can be used to change the module paths during runtime. It's useful when refactoring code,\nwhere you've moved a module from one location to another, but you still want to support the old import\npaths for backwards compatibility.\n\nArgs:\n modules (dict, optional): A dictionary mapping old module paths to new module paths.\n attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.\n\nExamples:\n >>> with temporary_modules({\"old.module\": \"new.module\"}, {\"old.module.attribute\": \"new.module.attribute\"}):\n >>> import old.module # this will now import new.module\n >>> from old.module import attribute # this will now import new.module.attribute\n\nNote:\n The changes are only in effect inside the context manager and are undone once the context manager exits.\n Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger\n applications or libraries. 
Use this function with caution.", "parameters": [ "modules", "attributes" ], "return_type": null, "decorators": [ "contextlib.contextmanager" ], "complexity_score": 7, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", 
"thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_temporary_modules_a0ebf6f9" }, { "content": "class SafeClass:\n \"\"\"A placeholder class to replace unknown classes during unpickling.\"\"\"\n\n def __init__(self, *args, **kwargs):\n \"\"\"Initialize SafeClass instance, ignoring all arguments.\"\"\"\n pass\n\n def __call__(self, *args, **kwargs):\n \"\"\"Run SafeClass instance, ignoring all arguments.\"\"\"\n pass", "chunk_type": "class", "name": "SafeClass", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1365, "end_line": 1374, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "A placeholder class to replace unknown classes during unpickling.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", 
"ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "class_SafeClass_75518994" }, { "content": "class SafeUnpickler(pickle.Unpickler):\n \"\"\"Custom Unpickler that replaces unknown classes with SafeClass.\"\"\"\n\n def find_class(self, module, name):\n \"\"\"\n Attempt to find a class, returning SafeClass if not among safe modules.\n\n Args:\n module (str): Module name.\n name (str): Class name.\n\n Returns:\n (type): Found class or SafeClass.\n \"\"\"\n safe_modules = (\n \"torch\",\n \"collections\",\n \"collections.abc\",\n \"builtins\",\n \"math\",\n \"numpy\",\n # Add other modules considered safe\n )\n if module in safe_modules:\n return super().find_class(module, name)\n else:\n return SafeClass", "chunk_type": "class", "name": "SafeUnpickler", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1377, "end_line": 1403, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Custom Unpickler that replaces unknown classes with SafeClass.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", 
"ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss", "pickle.Unpickler" ], "chunk_id": "class_SafeUnpickler_1ef84d7c" }, { "content": "def torch_safe_load(weight, safe_only=False):\n \"\"\"\n Attempt to load a PyTorch model with the torch.load() function. 
If a ModuleNotFoundError is raised, it catches the\n error, logs a warning message, and attempts to install the missing module via the check_requirements() function.\n After installation, the function again attempts to load the model using torch.load().\n\n Args:\n weight (str): The file path of the PyTorch model.\n safe_only (bool): If True, replace unknown classes with SafeClass during loading.\n\n Returns:\n ckpt (dict): The loaded model checkpoint.\n file (str): The loaded filename.\n\n Examples:\n >>> from ultralytics.nn.tasks import torch_safe_load\n >>> ckpt, file = torch_safe_load(\"path/to/best.pt\", safe_only=True)\n \"\"\"\n from ultralytics.utils.downloads import attempt_download_asset\n\n check_suffix(file=weight, suffix=\".pt\")\n file = attempt_download_asset(weight) # search online if missing locally\n try:\n with temporary_modules(\n modules={\n \"ultralytics.yolo.utils\": \"ultralytics.utils\",\n \"ultralytics.yolo.v8\": \"ultralytics.models.yolo\",\n \"ultralytics.yolo.data\": \"ultralytics.data\",\n },\n attributes={\n \"ultralytics.nn.modules.block.Silence\": \"torch.nn.Identity\", # YOLOv9e\n \"ultralytics.nn.tasks.YOLOv10DetectionModel\": \"ultralytics.nn.tasks.DetectionModel\", # YOLOv10\n \"ultralytics.utils.loss.v10DetectLoss\": \"ultralytics.utils.loss.E2EDetectLoss\", # YOLOv10\n },\n ):\n if safe_only:\n # Load via custom pickle module\n safe_pickle = types.ModuleType(\"safe_pickle\")\n safe_pickle.Unpickler = SafeUnpickler\n safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()\n with open(file, \"rb\") as f:\n ckpt = torch_load(f, pickle_module=safe_pickle)\n else:\n ckpt = torch_load(file, map_location=\"cpu\")\n\n except ModuleNotFoundError as e: # e.name is missing module name\n if e.name == \"models\":\n raise TypeError(\n emojis(\n f\"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained \"\n f\"with https://github.com/ultralytics/yolov5.\\nThis model is NOT forwards compatible with \"\n f\"YOLOv8 at https://github.com/ultralytics/ultralytics.\"\n f\"\\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to \"\n f\"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'\"\n )\n ) from e\n elif e.name == \"numpy._core\":\n raise ModuleNotFoundError(\n emojis(\n f\"ERROR ❌️ {weight} requires numpy>=1.26.1, however numpy=={__import__('numpy').__version__} is installed.\"\n )\n ) from e\n LOGGER.warning(\n f\"{weight} appears to require '{e.name}', which is not in Ultralytics requirements.\"\n f\"\\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future.\"\n f\"\\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to \"\n f\"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'\"\n )\n check_requirements(e.name) # install missing module\n ckpt = torch_load(file, map_location=\"cpu\")\n\n if not isinstance(ckpt, dict):\n # File is likely a YOLO instance saved with i.e. torch.save(model, \"saved_model.pt\")\n LOGGER.warning(\n f\"The file '{weight}' appears to be improperly saved or formatted. 
\"\n f\"For optimal results, use model.save('filename.pt') to correctly save YOLO models.\"\n )\n ckpt = {\"model\": ckpt.model}\n\n return ckpt, file", "chunk_type": "function", "name": "torch_safe_load", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1406, "end_line": 1485, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the\nerror, logs a warning message, and attempts to install the missing module via the check_requirements() function.\nAfter installation, the function again attempts to load the model using torch.load().\n\nArgs:\n weight (str): The file path of the PyTorch model.\n safe_only (bool): If True, replace unknown classes with SafeClass during loading.\n\nReturns:\n ckpt (dict): The loaded model checkpoint.\n file (str): The loaded filename.\n\nExamples:\n >>> from ultralytics.nn.tasks import torch_safe_load\n >>> ckpt, file = torch_safe_load(\"path/to/best.pt\", safe_only=True)", "parameters": [ "weight", "safe_only" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", 
"ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_torch_safe_load_db4d8b8e" }, { "content": "def attempt_load_weights(weights, device=None, inplace=True, fuse=False):\n \"\"\"\n Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.\n\n Args:\n weights (str | List[str]): Model weights path(s).\n device (torch.device, optional): Device to load model to.\n inplace (bool): Whether to do inplace operations.\n fuse (bool): Whether to fuse model.\n\n Returns:\n (torch.nn.Module): Loaded model.\n \"\"\"\n ensemble = Ensemble()\n for w in weights if isinstance(weights, list) else [weights]:\n ckpt, w = torch_safe_load(w) # load ckpt\n args = {**DEFAULT_CFG_DICT, **ckpt[\"train_args\"]} if \"train_args\" in ckpt else None # combined args\n model = (ckpt.get(\"ema\") or ckpt[\"model\"]).to(device).float() # FP32 model\n\n # Model compatibility updates\n model.args = args # attach args to model\n model.pt_path = w # attach *.pt file path to model\n model.task = getattr(model, \"task\", guess_model_task(model))\n if not hasattr(model, \"stride\"):\n model.stride = torch.tensor([32.0])\n\n # Append\n ensemble.append(model.fuse().eval() if fuse and hasattr(model, \"fuse\") else model.eval()) # model in eval mode\n\n # Module updates\n for m in ensemble.modules():\n if hasattr(m, \"inplace\"):\n m.inplace = inplace\n elif isinstance(m, torch.nn.Upsample) and not hasattr(m, \"recompute_scale_factor\"):\n m.recompute_scale_factor = None # torch 1.11.0 compatibility\n\n # Return model\n if len(ensemble) == 1:\n return ensemble[-1]\n\n # Return ensemble\n LOGGER.info(f\"Ensemble created with {weights}\\n\")\n for k in \"names\", \"nc\", \"yaml\":\n setattr(ensemble, k, getattr(ensemble[0], k))\n ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride\n assert all(ensemble[0].nc == m.nc for m in ensemble), f\"Models differ in class counts {[m.nc for m in ensemble]}\"\n return ensemble", "chunk_type": "function", "name": "attempt_load_weights", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1488, "end_line": 1534, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.\n\nArgs:\n weights (str | List[str]): Model weights path(s).\n device (torch.device, optional): Device to load model to.\n inplace (bool): Whether to do inplace operations.\n fuse (bool): Whether to 
fuse model.\n\nReturns:\n (torch.nn.Module): Loaded model.", "parameters": [ "weights", "device", "inplace", "fuse" ], "return_type": null, "decorators": [], "complexity_score": 11, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", 
"ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_attempt_load_weights_3e642728" }, { "content": "def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):\n \"\"\"\n Load a single model weights.\n\n Args:\n weight (str): Model weight path.\n device (torch.device, optional): Device to load model to.\n inplace (bool): Whether to do inplace operations.\n fuse (bool): Whether to fuse model.\n\n Returns:\n model (torch.nn.Module): Loaded model.\n ckpt (dict): Model checkpoint dictionary.\n \"\"\"\n ckpt, weight = torch_safe_load(weight) # load ckpt\n args = {**DEFAULT_CFG_DICT, **(ckpt.get(\"train_args\", {}))} # combine model and default args, preferring model args\n model = (ckpt.get(\"ema\") or ckpt[\"model\"]).to(device).float() # FP32 model\n\n # Model compatibility updates\n model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model\n model.pt_path = weight # attach *.pt file path to model\n model.task = getattr(model, \"task\", guess_model_task(model))\n if not hasattr(model, \"stride\"):\n model.stride = torch.tensor([32.0])\n\n model = model.fuse().eval() if fuse and hasattr(model, \"fuse\") else model.eval() # model in eval mode\n\n # Module updates\n for m in model.modules():\n if hasattr(m, \"inplace\"):\n m.inplace = inplace\n elif isinstance(m, torch.nn.Upsample) and not hasattr(m, \"recompute_scale_factor\"):\n m.recompute_scale_factor = None # torch 1.11.0 compatibility\n\n # Return model and ckpt\n return model, ckpt", "chunk_type": "function", "name": "attempt_load_one_weight", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1537, "end_line": 1572, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Load a single model weights.\n\nArgs:\n weight (str): Model weight path.\n device (torch.device, optional): Device to load model to.\n inplace (bool): Whether to do inplace operations.\n fuse (bool): Whether to fuse model.\n\nReturns:\n model (torch.nn.Module): Loaded model.\n ckpt (dict): Model checkpoint dictionary.", "parameters": [ "weight", "device", "inplace", "fuse" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", 
"ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_attempt_load_one_weight_aea9d2c2" }, { "content": "def parse_model(d, ch, verbose=True):\n \"\"\"\n Parse a YOLO model.yaml dictionary into a PyTorch model.\n\n Args:\n d (dict): Model dictionary.\n ch (int): Input channels.\n verbose (bool): Whether to print model details.\n\n Returns:\n model (torch.nn.Sequential): PyTorch model.\n save (list): Sorted list of output layers.\n \"\"\"\n import ast\n\n # Args\n legacy = True # backward compatibility for v3/v5/v8/v9 models\n max_channels = float(\"inf\")\n nc, act, scales = (d.get(x) for x in (\"nc\", \"activation\", \"scales\"))\n depth, width, kpt_shape = (d.get(x, 1.0) for x in (\"depth_multiple\", \"width_multiple\", \"kpt_shape\"))\n if scales:\n scale = d.get(\"scale\")\n if not scale:\n scale = tuple(scales.keys())[0]\n LOGGER.warning(f\"no model scale passed. Assuming scale='{scale}'.\")\n depth, width, max_channels = scales[scale]\n\n if act:\n Conv.default_act = eval(act) # redefine default activation, i.e. 
Conv.default_act = torch.nn.SiLU()\n if verbose:\n LOGGER.info(f\"{colorstr('activation:')} {act}\") # print\n\n if verbose:\n LOGGER.info(f\"\\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}\")\n ch = [ch]\n layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out\n base_modules = frozenset(\n {\n Classify,\n Conv,\n ConvTranspose,\n GhostConv,\n Bottleneck,\n GhostBottleneck,\n SPP,\n SPPF,\n C2fPSA,\n C2PSA,\n DWConv,\n Focus,\n BottleneckCSP,\n C1,\n C2,\n C2f,\n C3k2,\n RepNCSPELAN4,\n ELAN1,\n ADown,\n AConv,\n SPPELAN,\n C2fAttn,\n C3,\n C3TR,\n C3Ghost,\n torch.nn.ConvTranspose2d,\n DWConvTranspose2d,\n C3x,\n RepC3,\n PSA,\n SCDown,\n C2fCIB,\n A2C2f,\n }\n )\n repeat_modules = frozenset( # modules with 'repeat' arguments\n {\n BottleneckCSP,\n C1,\n C2,\n C2f,\n C3k2,\n C2fAttn,\n C3,\n C3TR,\n C3Ghost,\n C3x,\n RepC3,\n C2fPSA,\n C2fCIB,\n C2PSA,\n A2C2f,\n }\n )\n for i, (f, n, m, args) in enumerate(d[\"backbone\"] + d[\"head\"]): # from, number, module, args\n m = (\n getattr(torch.nn, m[3:])\n if \"nn.\" in m\n else getattr(__import__(\"torchvision\").ops, m[16:])\n if \"torchvision.ops.\" in m\n else globals()[m]\n ) # get module\n for j, a in enumerate(args):\n if isinstance(a, str):\n with contextlib.suppress(ValueError):\n args[j] = locals()[a] if a in locals() else ast.literal_eval(a)\n n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain\n if m in base_modules:\n c1, c2 = ch[f], args[0]\n if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)\n c2 = make_divisible(min(c2, max_channels) * width, 8)\n if m is C2fAttn: # set 1) embed channels and 2) num heads\n args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)\n args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])\n\n args = [c1, c2, *args[1:]]\n if m in repeat_modules:\n args.insert(2, n) # number of repeats\n n = 1\n if m is C3k2: # for M/L/X sizes\n legacy = False\n if scale in \"mlx\":\n args[3] = True\n if m is A2C2f:\n legacy = False\n if scale in \"lx\": # for L/X sizes\n args.extend((True, 1.2))\n if m is C2fCIB:\n legacy = False\n elif m is AIFI:\n args = [ch[f], *args]\n elif m in frozenset({HGStem, HGBlock}):\n c1, cm, c2 = ch[f], args[0], args[1]\n args = [c1, cm, c2, *args[2:]]\n if m is HGBlock:\n args.insert(4, n) # number of repeats\n n = 1\n elif m is ResNetLayer:\n c2 = args[1] if args[3] else args[1] * 4\n elif m is torch.nn.BatchNorm2d:\n args = [ch[f]]\n elif m is Concat:\n c2 = sum(ch[x] for x in f)\n elif m in frozenset(\n {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}\n ):\n args.append([ch[x] for x in f])\n if m is Segment or m is YOLOESegment:\n args[2] = make_divisible(min(args[2], max_channels) * width, 8)\n if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:\n m.legacy = legacy\n elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1\n args.insert(1, [ch[x] for x in f])\n elif m is CBLinear:\n c2 = args[0]\n c1 = ch[f]\n args = [c1, c2, *args[1:]]\n elif m is CBFuse:\n c2 = ch[f[-1]]\n elif m in frozenset({TorchVision, Index}):\n c2 = args[0]\n c1 = ch[f]\n args = [*args[1:]]\n else:\n c2 = ch[f]\n\n m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module\n t = str(m)[8:-2].replace(\"__main__.\", \"\") # module type\n m_.np = sum(x.numel() for x in m_.parameters()) # number params\n m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, 
type\n if verbose:\n LOGGER.info(f\"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}\") # print\n save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist\n layers.append(m_)\n if i == 0:\n ch = []\n ch.append(c2)\n return torch.nn.Sequential(*layers), sorted(save)", "chunk_type": "function", "name": "parse_model", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1575, "end_line": 1751, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": "Parse a YOLO model.yaml dictionary into a PyTorch model.\n\nArgs:\n d (dict): Model dictionary.\n ch (int): Input channels.\n verbose (bool): Whether to print model details.\n\nReturns:\n model (torch.nn.Sequential): PyTorch model.\n save (list): Sorted list of output layers.", "parameters": [ "d", "ch", "verbose" ], "return_type": null, "decorators": [], "complexity_score": 41, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", 
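The depth/width gain arithmetic that `parse_model` applies to each YAML entry can be reproduced in isolation. The scale values below are illustrative (roughly those of an "n" scale), and `make_divisible` is re-implemented locally for the sketch rather than imported, mirroring the rounding-up-to-a-multiple behaviour the parser relies on.

```python
# Standalone sketch of the depth/width scaling used by parse_model; values are illustrative.
import math

def make_divisible(x, divisor=8):
    # Local stand-in: round up to the nearest multiple of `divisor`.
    return math.ceil(x / divisor) * divisor

depth, width, max_channels = 0.33, 0.25, 1024

n = 3                                              # repeats requested by the YAML entry
n_scaled = max(round(n * depth), 1) if n > 1 else n
print(n_scaled)                                    # 1 -> block is built with a single repeat

c2 = 256                                           # output channels requested by the YAML entry
c2_scaled = make_divisible(min(c2, max_channels) * width, 8)
print(c2_scaled)                                   # 64 -> shrunk by width, kept divisible by 8
```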
"ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_parse_model_3429e190" }, { "content": "def yaml_model_load(path):\n \"\"\"\n Load a YOLOv8 model from a YAML file.\n\n Args:\n path (str | Path): Path to the YAML file.\n\n Returns:\n (dict): Model dictionary.\n \"\"\"\n path = Path(path)\n if path.stem in (f\"yolov{d}{x}6\" for x in \"nsmlx\" for d in (5, 8)):\n new_stem = re.sub(r\"(\\d+)([nslmx])6(.+)?$\", r\"\\1\\2-p6\\3\", path.stem)\n LOGGER.warning(f\"Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.\")\n path = path.with_name(new_stem + path.suffix)\n\n unified_path = re.sub(r\"(\\d+)([nslmx])(.+)?$\", r\"\\1\\3\", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml\n yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)\n d = YAML.load(yaml_file) # model dict\n d[\"scale\"] = guess_model_scale(path)\n d[\"yaml_file\"] = str(path)\n return d", "chunk_type": "function", "name": "yaml_model_load", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1754, "end_line": 1775, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Load a YOLOv8 model from a YAML file.\n\nArgs:\n path (str | Path): Path to the YAML file.\n\nReturns:\n (dict): Model dictionary.", "parameters": [ "path" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", 
"ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_yaml_model_load_375a49f3" }, { "content": "def guess_model_scale(model_path):\n \"\"\"\n Extract the size character n, s, m, l, or x of the model's scale from the model path.\n\n Args:\n model_path (str | Path): The path to the YOLO model's YAML file.\n\n Returns:\n (str): The size character of the model's scale (n, s, m, l, or x).\n \"\"\"\n try:\n return re.search(r\"yolo(e-)?[v]?\\d+([nslmx])\", Path(model_path).stem).group(2) # noqa\n except AttributeError:\n return \"\"", "chunk_type": "function", "name": "guess_model_scale", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1778, "end_line": 1791, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Extract the size character n, s, m, l, or x of the model's scale from the model path.\n\nArgs:\n model_path (str | Path): The path to the YOLO model's YAML file.\n\nReturns:\n (str): The size character of the model's scale (n, s, m, l, or x).", "parameters": [ "model_path" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", 
"ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", "ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_guess_model_scale_87aa6935" }, { "content": "def guess_model_task(model):\n \"\"\"\n Guess the task of a PyTorch model from its architecture or configuration.\n\n Args:\n model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.\n\n Returns:\n (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').\n \"\"\"\n\n def cfg2task(cfg):\n \"\"\"Guess from YAML dictionary.\"\"\"\n m = cfg[\"head\"][-1][-2].lower() 
# output module name\n if m in {\"classify\", \"classifier\", \"cls\", \"fc\"}:\n return \"classify\"\n if \"detect\" in m:\n return \"detect\"\n if \"segment\" in m:\n return \"segment\"\n if m == \"pose\":\n return \"pose\"\n if m == \"obb\":\n return \"obb\"\n\n # Guess from model cfg\n if isinstance(model, dict):\n with contextlib.suppress(Exception):\n return cfg2task(model)\n # Guess from PyTorch model\n if isinstance(model, torch.nn.Module): # PyTorch model\n for x in \"model.args\", \"model.model.args\", \"model.model.model.args\":\n with contextlib.suppress(Exception):\n return eval(x)[\"task\"]\n for x in \"model.yaml\", \"model.model.yaml\", \"model.model.model.yaml\":\n with contextlib.suppress(Exception):\n return cfg2task(eval(x))\n for m in model.modules():\n if isinstance(m, (Segment, YOLOESegment)):\n return \"segment\"\n elif isinstance(m, Classify):\n return \"classify\"\n elif isinstance(m, Pose):\n return \"pose\"\n elif isinstance(m, OBB):\n return \"obb\"\n elif isinstance(m, (Detect, WorldDetect, YOLOEDetect, v10Detect)):\n return \"detect\"\n\n # Guess from model filename\n if isinstance(model, (str, Path)):\n model = Path(model)\n if \"-seg\" in model.stem or \"segment\" in model.parts:\n return \"segment\"\n elif \"-cls\" in model.stem or \"classify\" in model.parts:\n return \"classify\"\n elif \"-pose\" in model.stem or \"pose\" in model.parts:\n return \"pose\"\n elif \"-obb\" in model.stem or \"obb\" in model.parts:\n return \"obb\"\n elif \"detect\" in model.parts:\n return \"detect\"\n\n # Unable to determine task from model\n LOGGER.warning(\n \"Unable to automatically guess model task, assuming 'task=detect'. \"\n \"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.\"\n )\n return \"detect\" # assume detect", "chunk_type": "function", "name": "guess_model_task", "file_path": "ultralytics\\ultralytics\\nn\\tasks.py", "start_line": 1794, "end_line": 1862, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Guess the task of a PyTorch model from its architecture or configuration.\n\nArgs:\n model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.\n\nReturns:\n (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 22, "dependencies": [ "contextlib", "pickle", "re", "types", "copy.deepcopy", "pathlib.Path", "torch", "torch.nn", "ultralytics.nn.autobackend.check_class_names", "ultralytics.nn.modules.AIFI", "ultralytics.nn.modules.C1", "ultralytics.nn.modules.C2", "ultralytics.nn.modules.C2PSA", "ultralytics.nn.modules.C3", "ultralytics.nn.modules.C3TR", "ultralytics.nn.modules.ELAN1", "ultralytics.nn.modules.OBB", "ultralytics.nn.modules.PSA", "ultralytics.nn.modules.SPP", "ultralytics.nn.modules.SPPELAN", "ultralytics.nn.modules.SPPF", "ultralytics.nn.modules.A2C2f", "ultralytics.nn.modules.AConv", "ultralytics.nn.modules.ADown", "ultralytics.nn.modules.Bottleneck", "ultralytics.nn.modules.BottleneckCSP", "ultralytics.nn.modules.C2f", "ultralytics.nn.modules.C2fAttn", "ultralytics.nn.modules.C2fCIB", "ultralytics.nn.modules.C2fPSA", "ultralytics.nn.modules.C3Ghost", "ultralytics.nn.modules.C3k2", "ultralytics.nn.modules.C3x", "ultralytics.nn.modules.CBFuse", "ultralytics.nn.modules.CBLinear", "ultralytics.nn.modules.Classify", "ultralytics.nn.modules.Concat", "ultralytics.nn.modules.Conv", "ultralytics.nn.modules.Conv2", "ultralytics.nn.modules.ConvTranspose", 
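When neither the model args nor the YAML reveal a task, `guess_model_task` falls back to a filename heuristic; the paths below are illustrative and exercise only that final branch.

```python
# Filename fallback of guess_model_task, exercised with illustrative paths.
from ultralytics.nn.tasks import guess_model_task

print(guess_model_task("yolo11n-seg.pt"))   # "segment"  (-seg suffix)
print(guess_model_task("yolo11n-pose.pt"))  # "pose"     (-pose suffix)
print(guess_model_task("yolo11n.pt"))       # "detect"   (default, emitted with a warning)
```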
"ultralytics.nn.modules.Detect", "ultralytics.nn.modules.DWConv", "ultralytics.nn.modules.DWConvTranspose2d", "ultralytics.nn.modules.Focus", "ultralytics.nn.modules.GhostBottleneck", "ultralytics.nn.modules.GhostConv", "ultralytics.nn.modules.HGBlock", "ultralytics.nn.modules.HGStem", "ultralytics.nn.modules.ImagePoolingAttn", "ultralytics.nn.modules.Index", "ultralytics.nn.modules.LRPCHead", "ultralytics.nn.modules.Pose", "ultralytics.nn.modules.RepC3", "ultralytics.nn.modules.RepConv", "ultralytics.nn.modules.RepNCSPELAN4", "ultralytics.nn.modules.RepVGGDW", "ultralytics.nn.modules.ResNetLayer", "ultralytics.nn.modules.RTDETRDecoder", "ultralytics.nn.modules.SCDown", "ultralytics.nn.modules.Segment", "ultralytics.nn.modules.TorchVision", "ultralytics.nn.modules.WorldDetect", "ultralytics.nn.modules.YOLOEDetect", "ultralytics.nn.modules.YOLOESegment", "ultralytics.nn.modules.v10Detect", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.YAML", "ultralytics.utils.colorstr", "ultralytics.utils.emojis", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_suffix", "ultralytics.utils.checks.check_yaml", "ultralytics.utils.loss.E2EDetectLoss", "ultralytics.utils.loss.v8ClassificationLoss", "ultralytics.utils.loss.v8DetectionLoss", "ultralytics.utils.loss.v8OBBLoss", "ultralytics.utils.loss.v8PoseLoss", "ultralytics.utils.loss.v8SegmentationLoss", "ultralytics.utils.ops.make_divisible", "ultralytics.utils.patches.torch_load", "ultralytics.utils.plotting.feature_visualization", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.fuse_deconv_and_bn", "ultralytics.utils.torch_utils.initialize_weights", "ultralytics.utils.torch_utils.intersect_dicts", "ultralytics.utils.torch_utils.model_info", "ultralytics.utils.torch_utils.scale_img", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.utils.torch_utils.time_sync", "sys", "importlib.import_module", "ultralytics.utils.downloads.attempt_download_asset", "ast", "ultralytics.models.utils.loss.RTDETRDetectionLoss", "ultralytics.nn.text_model.build_text_model", "ultralytics.nn.text_model.build_text_model", "thop", "ultralytics.utils.loss.TVPDetectLoss", "ultralytics.utils.loss.TVPSegmentLoss" ], "chunk_id": "function_guess_model_task_b8387d17" }, { "content": "from abc import abstractmethod", "chunk_type": "import", "name": "abstractmethod", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_abstractmethod_d1f3072b" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_f4ff7c78" }, { "content": "from typing import List, Union", "chunk_type": "import", "name": "List, Union", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_List, Union_b45adaab" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_d7244787" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_5e0c6f99" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_10acb0e7" }, { "content": "from ultralytics.utils import checks", "chunk_type": "import", "name": "checks", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_checks_15e38da2" }, { "content": "from ultralytics.utils.torch_utils import smart_inference_mode", "chunk_type": "import", "name": "smart_inference_mode", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_smart_inference_mode_dcc0d413" }, { "content": "class TextModel(nn.Module):\n \"\"\"\n Abstract base class for text encoding models.\n\n This class defines the interface for text encoding models used in vision-language tasks. Subclasses must implement\n the tokenize and encode_text methods to provide text tokenization and encoding functionality.\n\n Methods:\n tokenize: Convert input texts to tokens for model processing.\n encode_text: Encode tokenized texts into normalized feature vectors.\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the TextModel base class.\"\"\"\n super().__init__()\n\n @abstractmethod\n def tokenize(self, texts):\n \"\"\"Convert input texts to tokens for model processing.\"\"\"\n pass\n\n @abstractmethod\n def encode_text(self, texts, dtype):\n \"\"\"Encode tokenized texts into normalized feature vectors.\"\"\"\n pass", "chunk_type": "class", "name": "TextModel", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 21, "end_line": 45, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Abstract base class for text encoding models.\n\nThis class defines the interface for text encoding models used in vision-language tasks. 
Subclasses must implement\nthe tokenize and encode_text methods to provide text tokenization and encoding functionality.\n\nMethods:\n tokenize: Convert input texts to tokens for model processing.\n encode_text: Encode tokenized texts into normalized feature vectors.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "abc.abstractmethod", "pathlib.Path", "typing.List", "typing.Union", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.smart_inference_mode", "clip", "clip", "ultralytics.utils.downloads.attempt_download_asset", "warnings", "ultralytics.download", "mobileclip", "mobileclip", "nn.Module" ], "chunk_id": "class_TextModel_5daa136f" }, { "content": "class CLIP(TextModel):\n \"\"\"\n Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.\n\n This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors\n that are aligned with corresponding image features in a shared embedding space.\n\n Attributes:\n model (clip.model.CLIP): The loaded CLIP model.\n device (torch.device): Device where the model is loaded.\n\n Methods:\n tokenize: Convert input texts to CLIP tokens.\n encode_text: Encode tokenized texts into normalized feature vectors.\n\n Examples:\n >>> import torch\n >>> device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n >>> clip_model = CLIP(size=\"ViT-B/32\", device=device)\n >>> tokens = clip_model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> text_features = clip_model.encode_text(tokens)\n >>> print(text_features.shape)\n \"\"\"\n\n def __init__(self, size: str, device: torch.device) -> None:\n \"\"\"\n Initialize the CLIP text encoder.\n\n This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads\n a pre-trained CLIP model of the specified size and prepares it for text encoding tasks.\n\n Args:\n size (str): Model size identifier (e.g., 'ViT-B/32').\n device (torch.device): Device to load the model on.\n\n Examples:\n >>> import torch\n >>> clip_model = CLIP(\"ViT-B/32\", device=torch.device(\"cuda:0\"))\n >>> text_features = clip_model.encode_text([\"a photo of a cat\", \"a photo of a dog\"])\n \"\"\"\n super().__init__()\n self.model, self.image_preprocess = clip.load(size, device=device)\n self.to(device)\n self.device = device\n self.eval()\n\n def tokenize(self, texts: Union[str, List[str]]) -> torch.Tensor:\n \"\"\"\n Convert input texts to CLIP tokens.\n\n Args:\n texts (str | List[str]): Input text or list of texts to tokenize.\n\n Returns:\n (torch.Tensor): Tokenized text tensor with shape (batch_size, context_length) ready for model processing.\n\n Examples:\n >>> model = CLIP(\"ViT-B/32\", device=\"cpu\")\n >>> tokens = model.tokenize(\"a photo of a cat\")\n >>> print(tokens.shape) # torch.Size([1, 77])\n \"\"\"\n return clip.tokenize(texts).to(self.device)\n\n @smart_inference_mode()\n def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:\n \"\"\"\n Encode tokenized texts into normalized feature vectors.\n\n This method processes tokenized text inputs through the CLIP model to generate feature vectors, which are then\n normalized to unit length. 
These normalized vectors can be used for text-image similarity comparisons.\n\n Args:\n texts (torch.Tensor): Tokenized text inputs, typically created using the tokenize() method.\n dtype (torch.dtype, optional): Data type for output features.\n\n Returns:\n (torch.Tensor): Normalized text feature vectors with unit length (L2 norm = 1).\n\n Examples:\n >>> clip_model = CLIP(\"ViT-B/32\", device=\"cuda\")\n >>> tokens = clip_model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = clip_model.encode_text(tokens)\n >>> features.shape\n torch.Size([2, 512])\n \"\"\"\n txt_feats = self.model.encode_text(texts).to(dtype)\n txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)\n return txt_feats\n\n @smart_inference_mode()\n def encode_image(self, image: Union[Image.Image, torch.Tensor], dtype: torch.dtype = torch.float32) -> torch.Tensor:\n \"\"\"\n Encode preprocessed images into normalized feature vectors.\n\n This method processes preprocessed image inputs through the CLIP model to generate feature vectors, which are then\n normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.\n\n Args:\n image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be\n converted to a tensor using the model's image preprocessing function.\n dtype (torch.dtype, optional): Data type for output features.\n\n Returns:\n (torch.Tensor): Normalized image feature vectors with unit length (L2 norm = 1).\n\n Examples:\n >>> from ultralytics.nn.text_model import CLIP\n >>> from PIL import Image\n >>> clip_model = CLIP(\"ViT-B/32\", device=\"cuda\")\n >>> image = Image.open(\"path/to/image.jpg\")\n >>> image_tensor = clip_model.image_preprocess(image).unsqueeze(0).to(\"cuda\")\n >>> features = clip_model.encode_image(image_tensor)\n >>> features.shape\n torch.Size([1, 512])\n \"\"\"\n if isinstance(image, Image.Image):\n image = self.image_preprocess(image).unsqueeze(0).to(self.device)\n img_feats = self.model.encode_image(image).to(dtype)\n img_feats = img_feats / img_feats.norm(p=2, dim=-1, keepdim=True)\n return img_feats", "chunk_type": "class", "name": "CLIP", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 48, "end_line": 167, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.\n\nThis class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors\nthat are aligned with corresponding image features in a shared embedding space.\n\nAttributes:\n model (clip.model.CLIP): The loaded CLIP model.\n device (torch.device): Device where the model is loaded.\n\nMethods:\n tokenize: Convert input texts to CLIP tokens.\n encode_text: Encode tokenized texts into normalized feature vectors.\n\nExamples:\n >>> import torch\n >>> device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n >>> clip_model = CLIP(size=\"ViT-B/32\", device=device)\n >>> tokens = clip_model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> text_features = clip_model.encode_text(tokens)\n >>> print(text_features.shape)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "abc.abstractmethod", "pathlib.Path", "typing.List", "typing.Union", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.smart_inference_mode", "clip", "clip", 
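A usage sketch for the `CLIP` text encoder above, assuming the optional `clip` package (OpenAI CLIP) is installed; because `encode_text` L2-normalises its output, cosine similarity between texts reduces to a plain matrix product.

```python
# Sketch only: requires the optional `clip` package to be importable.
import torch
from ultralytics.nn.text_model import CLIP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = CLIP("ViT-B/32", device=device)

tokens = encoder.tokenize(["a photo of a cat", "a photo of a dog"])
feats = encoder.encode_text(tokens)  # shape (2, 512), rows are unit length

# Unit-length rows mean cosine similarity is just a matrix product.
similarity = feats @ feats.T
print(similarity[0, 1].item())
```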
"ultralytics.utils.downloads.attempt_download_asset", "warnings", "ultralytics.download", "mobileclip", "mobileclip", "TextModel" ], "chunk_id": "class_CLIP_1ecef4eb" }, { "content": "class MobileCLIP(TextModel):\n \"\"\"\n Implement Apple's MobileCLIP text encoder for efficient text encoding.\n\n This class implements the TextModel interface using Apple's MobileCLIP model, providing efficient text encoding\n capabilities for vision-language tasks with reduced computational requirements compared to standard CLIP models.\n\n Attributes:\n model (mobileclip.model.MobileCLIP): The loaded MobileCLIP model.\n tokenizer (callable): Tokenizer function for processing text inputs.\n device (torch.device): Device where the model is loaded.\n config_size_map (dict): Mapping from size identifiers to model configuration names.\n\n Methods:\n tokenize: Convert input texts to MobileCLIP tokens.\n encode_text: Encode tokenized texts into normalized feature vectors.\n\n Examples:\n >>> device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n >>> text_encoder = MobileCLIP(size=\"s0\", device=device)\n >>> tokens = text_encoder.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = text_encoder.encode_text(tokens)\n \"\"\"\n\n config_size_map = {\"s0\": \"s0\", \"s1\": \"s1\", \"s2\": \"s2\", \"b\": \"b\", \"blt\": \"b\"}\n\n def __init__(self, size: str, device: torch.device) -> None:\n \"\"\"\n Initialize the MobileCLIP text encoder.\n\n This class implements the TextModel interface using Apple's MobileCLIP model for efficient text encoding.\n\n Args:\n size (str): Model size identifier (e.g., 's0', 's1', 's2', 'b', 'blt').\n device (torch.device): Device to load the model on.\n\n Examples:\n >>> import torch\n >>> model = MobileCLIP(\"s0\", device=torch.device(\"cpu\"))\n >>> tokens = model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = model.encode_text(tokens)\n \"\"\"\n try:\n import warnings\n\n # Suppress 'timm.models.layers is deprecated, please import via timm.layers' warning from mobileclip usage\n with warnings.catch_warnings():\n warnings.filterwarnings(\"ignore\", category=FutureWarning)\n import mobileclip\n except ImportError:\n # Ultralytics fork preferred since Apple MobileCLIP repo has incorrect version of torchvision\n checks.check_requirements(\"git+https://github.com/ultralytics/mobileclip.git\")\n import mobileclip\n\n super().__init__()\n config = self.config_size_map[size]\n file = f\"mobileclip_{size}.pt\"\n if not Path(file).is_file():\n from ultralytics import download\n\n download(f\"https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/{file}\")\n self.model = mobileclip.create_model_and_transforms(f\"mobileclip_{config}\", pretrained=file, device=device)[0]\n self.tokenizer = mobileclip.get_tokenizer(f\"mobileclip_{config}\")\n self.to(device)\n self.device = device\n self.eval()\n\n def tokenize(self, texts: List[str]) -> torch.Tensor:\n \"\"\"\n Convert input texts to MobileCLIP tokens.\n\n Args:\n texts (List[str]): List of text strings to tokenize.\n\n Returns:\n (torch.Tensor): Tokenized text inputs with shape (batch_size, sequence_length).\n\n Examples:\n >>> model = MobileCLIP(\"s0\", \"cpu\")\n >>> tokens = model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n \"\"\"\n return self.tokenizer(texts).to(self.device)\n\n @smart_inference_mode()\n def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:\n \"\"\"\n Encode tokenized texts into 
normalized feature vectors.\n\n Args:\n texts (torch.Tensor): Tokenized text inputs.\n dtype (torch.dtype, optional): Data type for output features.\n\n Returns:\n (torch.Tensor): Normalized text feature vectors with L2 normalization applied.\n\n Examples:\n >>> model = MobileCLIP(\"s0\", device=\"cpu\")\n >>> tokens = model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = model.encode_text(tokens)\n >>> features.shape\n torch.Size([2, 512]) # Actual dimension depends on model size\n \"\"\"\n text_features = self.model.encode_text(texts).to(dtype)\n text_features /= text_features.norm(p=2, dim=-1, keepdim=True)\n return text_features", "chunk_type": "class", "name": "MobileCLIP", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 170, "end_line": 274, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Implement Apple's MobileCLIP text encoder for efficient text encoding.\n\nThis class implements the TextModel interface using Apple's MobileCLIP model, providing efficient text encoding\ncapabilities for vision-language tasks with reduced computational requirements compared to standard CLIP models.\n\nAttributes:\n model (mobileclip.model.MobileCLIP): The loaded MobileCLIP model.\n tokenizer (callable): Tokenizer function for processing text inputs.\n device (torch.device): Device where the model is loaded.\n config_size_map (dict): Mapping from size identifiers to model configuration names.\n\nMethods:\n tokenize: Convert input texts to MobileCLIP tokens.\n encode_text: Encode tokenized texts into normalized feature vectors.\n\nExamples:\n >>> device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n >>> text_encoder = MobileCLIP(size=\"s0\", device=device)\n >>> tokens = text_encoder.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = text_encoder.encode_text(tokens)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "abc.abstractmethod", "pathlib.Path", "typing.List", "typing.Union", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.smart_inference_mode", "clip", "clip", "ultralytics.utils.downloads.attempt_download_asset", "warnings", "ultralytics.download", "mobileclip", "mobileclip", "TextModel" ], "chunk_id": "class_MobileCLIP_254eb3f3" }, { "content": "class MobileCLIPTS(TextModel):\n \"\"\"\n Load a TorchScript traced version of MobileCLIP.\n\n This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format, providing\n efficient text encoding capabilities for vision-language tasks with optimized inference performance.\n\n Attributes:\n encoder (torch.jit.ScriptModule): The loaded TorchScript MobileCLIP text encoder.\n tokenizer (callable): Tokenizer function for processing text inputs.\n device (torch.device): Device where the model is loaded.\n\n Methods:\n tokenize: Convert input texts to MobileCLIP tokens.\n encode_text: Encode tokenized texts into normalized feature vectors.\n\n Examples:\n >>> device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n >>> text_encoder = MobileCLIPTS(device=device)\n >>> tokens = text_encoder.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = text_encoder.encode_text(tokens)\n \"\"\"\n\n def __init__(self, device: torch.device):\n \"\"\"\n Initialize the MobileCLIP TorchScript text encoder.\n\n This class implements the TextModel interface using Apple's MobileCLIP model in 
TorchScript format for\n efficient text encoding with optimized inference performance.\n\n Args:\n device (torch.device): Device to load the model on.\n\n Examples:\n >>> model = MobileCLIPTS(device=torch.device(\"cpu\"))\n >>> tokens = model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = model.encode_text(tokens)\n \"\"\"\n super().__init__()\n from ultralytics.utils.downloads import attempt_download_asset\n\n self.encoder = torch.jit.load(attempt_download_asset(\"mobileclip_blt.ts\"), map_location=device)\n self.tokenizer = clip.clip.tokenize\n self.device = device\n\n def tokenize(self, texts: List[str]) -> torch.Tensor:\n \"\"\"\n Convert input texts to MobileCLIP tokens.\n\n Args:\n texts (List[str]): List of text strings to tokenize.\n\n Returns:\n (torch.Tensor): Tokenized text inputs with shape (batch_size, sequence_length).\n\n Examples:\n >>> model = MobileCLIPTS(\"cpu\")\n >>> tokens = model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n \"\"\"\n return self.tokenizer(texts).to(self.device)\n\n @smart_inference_mode()\n def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:\n \"\"\"\n Encode tokenized texts into normalized feature vectors.\n\n Args:\n texts (torch.Tensor): Tokenized text inputs.\n dtype (torch.dtype, optional): Data type for output features.\n\n Returns:\n (torch.Tensor): Normalized text feature vectors with L2 normalization applied.\n\n Examples:\n >>> model = MobileCLIPTS(device=\"cpu\")\n >>> tokens = model.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = model.encode_text(tokens)\n >>> features.shape\n torch.Size([2, 512]) # Actual dimension depends on model size\n \"\"\"\n # NOTE: no need to do normalization here as it's embedded in the torchscript model\n return self.encoder(texts).to(dtype)", "chunk_type": "class", "name": "MobileCLIPTS", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 277, "end_line": 358, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Load a TorchScript traced version of MobileCLIP.\n\nThis class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format, providing\nefficient text encoding capabilities for vision-language tasks with optimized inference performance.\n\nAttributes:\n encoder (torch.jit.ScriptModule): The loaded TorchScript MobileCLIP text encoder.\n tokenizer (callable): Tokenizer function for processing text inputs.\n device (torch.device): Device where the model is loaded.\n\nMethods:\n tokenize: Convert input texts to MobileCLIP tokens.\n encode_text: Encode tokenized texts into normalized feature vectors.\n\nExamples:\n >>> device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n >>> text_encoder = MobileCLIPTS(device=device)\n >>> tokens = text_encoder.tokenize([\"a photo of a cat\", \"a photo of a dog\"])\n >>> features = text_encoder.encode_text(tokens)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "abc.abstractmethod", "pathlib.Path", "typing.List", "typing.Union", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.smart_inference_mode", "clip", "clip", "ultralytics.utils.downloads.attempt_download_asset", "warnings", "ultralytics.download", "mobileclip", "mobileclip", "TextModel" ], "chunk_id": "class_MobileCLIPTS_3d36ad14" }, { "content": "def build_text_model(variant: str, device: torch.device = None) -> TextModel:\n 
\"\"\"\n Build a text encoding model based on the specified variant.\n\n Args:\n variant (str): Model variant in format \"base:size\" (e.g., \"clip:ViT-B/32\" or \"mobileclip:s0\").\n device (torch.device, optional): Device to load the model on.\n\n Returns:\n (TextModel): Instantiated text encoding model.\n\n Examples:\n >>> model = build_text_model(\"clip:ViT-B/32\", device=torch.device(\"cuda\"))\n >>> model = build_text_model(\"mobileclip:s0\", device=torch.device(\"cpu\"))\n \"\"\"\n base, size = variant.split(\":\")\n if base == \"clip\":\n return CLIP(size, device)\n elif base == \"mobileclip\":\n return MobileCLIPTS(device)\n else:\n raise ValueError(f\"Unrecognized base model: '{base}'. Supported base models: 'clip', 'mobileclip'.\")", "chunk_type": "function", "name": "build_text_model", "file_path": "ultralytics\\ultralytics\\nn\\text_model.py", "start_line": 361, "end_line": 382, "start_col": 0, "end_col": 108, "parent_name": null, "docstring": "Build a text encoding model based on the specified variant.\n\nArgs:\n variant (str): Model variant in format \"base:size\" (e.g., \"clip:ViT-B/32\" or \"mobileclip:s0\").\n device (torch.device, optional): Device to load the model on.\n\nReturns:\n (TextModel): Instantiated text encoding model.\n\nExamples:\n >>> model = build_text_model(\"clip:ViT-B/32\", device=torch.device(\"cuda\"))\n >>> model = build_text_model(\"mobileclip:s0\", device=torch.device(\"cpu\"))", "parameters": [ "variant: str", "device: torch.device" ], "return_type": "TextModel", "decorators": [], "complexity_score": 3, "dependencies": [ "abc.abstractmethod", "pathlib.Path", "typing.List", "typing.Union", "torch", "torch.nn", "PIL.Image", "ultralytics.utils.checks", "ultralytics.utils.torch_utils.smart_inference_mode", "clip", "clip", "ultralytics.utils.downloads.attempt_download_asset", "warnings", "ultralytics.download", "mobileclip", "mobileclip" ], "chunk_id": "function_build_text_model_f8fd37be" }, { "content": "from .tasks import (\n BaseModel,\n ClassificationModel,\n DetectionModel,\n SegmentationModel,\n attempt_load_one_weight,\n attempt_load_weights,\n guess_model_scale,\n guess_model_task,\n parse_model,\n torch_safe_load,\n yaml_model_load,\n)", "chunk_type": "import", "name": "BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, yaml_model_load", "file_path": "ultralytics\\ultralytics\\nn\\__init__.py", "start_line": 3, "end_line": 15, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, yaml_model_load_ebc4913f" }, { "content": "__all__ = (\n \"attempt_load_one_weight\",\n \"attempt_load_weights\",\n \"parse_model\",\n \"yaml_model_load\",\n \"guess_model_task\",\n \"guess_model_scale\",\n \"torch_safe_load\",\n \"DetectionModel\",\n \"SegmentationModel\",\n \"ClassificationModel\",\n \"BaseModel\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\__init__.py", "start_line": 17, "end_line": 29, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
"dependencies": null, "chunk_id": "variable___all___f7f282f7" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\ultralytics\\solutions\\ai_gym.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_3487370c" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\ai_gym.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_46c335ea" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\ai_gym.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_1463f35c" }, { "content": "class AIGym(BaseSolution):\n \"\"\"\n A class to manage gym steps of people in a real-time video stream based on their poses.\n\n This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts\n repetitions of exercises based on predefined angle thresholds for up and down positions.\n\n Attributes:\n states (Dict[float, int, str]): Stores per-track angle, count, and stage for workout monitoring.\n up_angle (float): Angle threshold for considering the 'up' position of an exercise.\n down_angle (float): Angle threshold for considering the 'down' position of an exercise.\n kpts (List[int]): Indices of keypoints used for angle calculation.\n\n Methods:\n process: Process a frame to detect poses, calculate angles, and count repetitions.\n\n Examples:\n >>> gym = AIGym(model=\"yolo11n-pose.pt\")\n >>> image = cv2.imread(\"gym_scene.jpg\")\n >>> results = gym.process(image)\n >>> processed_image = results.plot_im\n >>> cv2.imshow(\"Processed Image\", processed_image)\n >>> cv2.waitKey(0)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize AIGym for workout monitoring using pose estimation and predefined angles.\n\n Args:\n **kwargs (Any): Keyword arguments passed to the parent class constructor.\n model (str): Model name or path, defaults to \"yolo11n-pose.pt\".\n \"\"\"\n kwargs[\"model\"] = kwargs.get(\"model\", \"yolo11n-pose.pt\")\n super().__init__(**kwargs)\n self.states = defaultdict(lambda: {\"angle\": 0, \"count\": 0, \"stage\": \"-\"}) # Dict for count, angle and stage\n\n # Extract details from CFG single time for usage later\n self.up_angle = float(self.CFG[\"up_angle\"]) # Pose up predefined angle to consider up pose\n self.down_angle = float(self.CFG[\"down_angle\"]) # Pose down predefined angle to consider down pose\n self.kpts = self.CFG[\"kpts\"] # User selected kpts of workouts storage for further usage\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Monitor workouts using Ultralytics YOLO Pose Model.\n\n This function processes an input image to track and analyze human poses for workout 
monitoring. It uses\n the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined\n angle thresholds.\n\n Args:\n im0 (np.ndarray): Input image for processing.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`,\n 'workout_count' (list of completed reps),\n 'workout_stage' (list of current stages),\n 'workout_angle' (list of angles), and\n 'total_tracks' (total number of tracked individuals).\n\n Examples:\n >>> gym = AIGym()\n >>> image = cv2.imread(\"workout.jpg\")\n >>> results = gym.process(image)\n >>> processed_image = results.plot_im\n \"\"\"\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)\n\n if len(self.boxes):\n kpt_data = self.tracks.keypoints.data\n\n for i, k in enumerate(kpt_data):\n state = self.states[self.track_ids[i]] # get state details\n # Get keypoints and estimate the angle\n state[\"angle\"] = annotator.estimate_pose_angle(*[k[int(idx)] for idx in self.kpts])\n annotator.draw_specific_kpts(k, self.kpts, radius=self.line_width * 3)\n\n # Determine stage and count logic based on angle thresholds\n if state[\"angle\"] < self.down_angle:\n if state[\"stage\"] == \"up\":\n state[\"count\"] += 1\n state[\"stage\"] = \"down\"\n elif state[\"angle\"] > self.up_angle:\n state[\"stage\"] = \"up\"\n\n # Display angle, count, and stage text\n if self.show_labels:\n annotator.plot_angle_and_count_and_stage(\n angle_text=state[\"angle\"], # angle text for display\n count_text=state[\"count\"], # count text for workouts\n stage_text=state[\"stage\"], # stage position text\n center_kpt=k[int(self.kpts[1])], # center keypoint for display\n )\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output image, if environment support display\n\n # Return SolutionResults\n return SolutionResults(\n plot_im=plot_im,\n workout_count=[v[\"count\"] for v in self.states.values()],\n workout_stage=[v[\"stage\"] for v in self.states.values()],\n workout_angle=[v[\"angle\"] for v in self.states.values()],\n total_tracks=len(self.track_ids),\n )", "chunk_type": "class", "name": "AIGym", "file_path": "ultralytics\\ultralytics\\solutions\\ai_gym.py", "start_line": 9, "end_line": 114, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class to manage gym steps of people in a real-time video stream based on their poses.\n\nThis class extends BaseSolution to monitor workouts using YOLO pose estimation models. 
It tracks and counts\nrepetitions of exercises based on predefined angle thresholds for up and down positions.\n\nAttributes:\n states (Dict[float, int, str]): Stores per-track angle, count, and stage for workout monitoring.\n up_angle (float): Angle threshold for considering the 'up' position of an exercise.\n down_angle (float): Angle threshold for considering the 'down' position of an exercise.\n kpts (List[int]): Indices of keypoints used for angle calculation.\n\nMethods:\n process: Process a frame to detect poses, calculate angles, and count repetitions.\n\nExamples:\n >>> gym = AIGym(model=\"yolo11n-pose.pt\")\n >>> image = cv2.imread(\"gym_scene.jpg\")\n >>> results = gym.process(image)\n >>> processed_image = results.plot_im\n >>> cv2.imshow(\"Processed Image\", processed_image)\n >>> cv2.waitKey(0)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.defaultdict", "typing.Any", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "BaseSolution" ], "chunk_id": "class_AIGym_2d048bb0" }, { "content": "from itertools import cycle", "chunk_type": "import", "name": "cycle", "file_path": "ultralytics\\ultralytics\\solutions\\analytics.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cycle_4867b172" }, { "content": "from typing import Any, Dict, Optional", "chunk_type": "import", "name": "Any, Dict, Optional", "file_path": "ultralytics\\ultralytics\\solutions\\analytics.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Optional_f2a819f7" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\analytics.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_0011d702" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\analytics.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_ecb3806f" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionResults # Import a parent class", "chunk_type": "import", "name": "BaseSolution, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\analytics.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionResults_b6fd5bee" }, { "content": "class Analytics(BaseSolution):\n \"\"\"\n A class for creating and updating various types of charts for visual analytics.\n\n This class extends BaseSolution to provide functionality for generating line, bar, pie, and area 
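A minimal sketch of running the `AIGym` chunk above on a video stream, following its docstring example. The video path is hypothetical, and `kpts=[6, 8, 10]` simply restates the default keypoint indices from `SolutionConfig`.

```python
# Sketch: pose-based repetition counting with AIGym over a video.
import cv2

from ultralytics.solutions.ai_gym import AIGym

gym = AIGym(model="yolo11n-pose.pt", kpts=[6, 8, 10], show=False)
cap = cv2.VideoCapture("workout.mp4")  # hypothetical input video

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = gym.process(frame)  # SolutionResults: plot_im, workout_count, workout_stage, ...
    print(results.workout_count, results.workout_stage)

cap.release()
```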
charts\n based on object detection and tracking data.\n\n Attributes:\n type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').\n x_label (str): Label for the x-axis.\n y_label (str): Label for the y-axis.\n bg_color (str): Background color of the chart frame.\n fg_color (str): Foreground color of the chart frame.\n title (str): Title of the chart window.\n max_points (int): Maximum number of data points to display on the chart.\n fontsize (int): Font size for text display.\n color_cycle (cycle): Cyclic iterator for chart colors.\n total_counts (int): Total count of detected objects (used for line charts).\n clswise_count (Dict[str, int]): Dictionary for class-wise object counts.\n fig (Figure): Matplotlib figure object for the chart.\n ax (Axes): Matplotlib axes object for the chart.\n canvas (FigureCanvasAgg): Canvas for rendering the chart.\n lines (dict): Dictionary to store line objects for area charts.\n color_mapping (Dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.\n\n Methods:\n process: Process image data and update the chart.\n update_graph: Update the chart with new data points.\n\n Examples:\n >>> analytics = Analytics(analytics_type=\"line\")\n >>> frame = cv2.imread(\"image.jpg\")\n >>> results = analytics.process(frame, frame_number=1)\n >>> cv2.imshow(\"Analytics\", results.plot_im)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize Analytics class with various chart types for visual data representation.\"\"\"\n super().__init__(**kwargs)\n\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n from matplotlib.backends.backend_agg import FigureCanvasAgg\n from matplotlib.figure import Figure\n\n self.type = self.CFG[\"analytics_type\"] # type of analytics i.e \"line\", \"pie\", \"bar\" or \"area\" charts.\n self.x_label = \"Classes\" if self.type in {\"bar\", \"pie\"} else \"Frame#\"\n self.y_label = \"Total Counts\"\n\n # Predefined data\n self.bg_color = \"#F3F3F3\" # background color of frame\n self.fg_color = \"#111E68\" # foreground color of frame\n self.title = \"Ultralytics Solutions\" # window name\n self.max_points = 45 # maximum points to be drawn on window\n self.fontsize = 25 # text font size for display\n figsize = self.CFG[\"figsize\"] # set output image size i.e (12.8, 7.2) -> w = 1280, h = 720\n self.color_cycle = cycle([\"#DD00BA\", \"#042AFF\", \"#FF4447\", \"#7D24FF\", \"#BD00FF\"])\n\n self.total_counts = 0 # count variable for storing total counts i.e. 
for line\n self.clswise_count = {} # dictionary for class-wise counts\n self.update_every = kwargs.get(\"update_every\", 30) # Only update graph every 30 frames by default\n self.last_plot_im = None # Cache of the last rendered chart\n\n # Ensure line and area chart\n if self.type in {\"line\", \"area\"}:\n self.lines = {}\n self.fig = Figure(facecolor=self.bg_color, figsize=figsize)\n self.canvas = FigureCanvasAgg(self.fig) # Set common axis properties\n self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)\n if self.type == \"line\":\n (self.line,) = self.ax.plot([], [], color=\"cyan\", linewidth=self.line_width)\n elif self.type in {\"bar\", \"pie\"}:\n # Initialize bar or pie plot\n self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color)\n self.canvas = FigureCanvasAgg(self.fig) # Set common axis properties\n self.ax.set_facecolor(self.bg_color)\n self.color_mapping = {}\n\n if self.type == \"pie\": # Ensure pie chart is circular\n self.ax.axis(\"equal\")\n\n def process(self, im0: np.ndarray, frame_number: int) -> SolutionResults:\n \"\"\"\n Process image data and run object tracking to update analytics charts.\n\n Args:\n im0 (np.ndarray): Input image for processing.\n frame_number (int): Video frame number for plotting the data.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked objects)\n and 'classwise_count' (dict, per-class object count).\n\n Raises:\n ModuleNotFoundError: If an unsupported chart type is specified.\n\n Examples:\n >>> analytics = Analytics(analytics_type=\"line\")\n >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)\n >>> results = analytics.process(frame, frame_number=1)\n \"\"\"\n self.extract_tracks(im0) # Extract tracks\n if self.type == \"line\":\n for _ in self.boxes:\n self.total_counts += 1\n update_required = frame_number % self.update_every == 0 or self.last_plot_im is None\n if update_required:\n self.last_plot_im = self.update_graph(frame_number=frame_number)\n plot_im = self.last_plot_im\n self.total_counts = 0\n elif self.type in {\"pie\", \"bar\", \"area\"}:\n from collections import Counter\n\n self.clswise_count = Counter(self.names[int(cls)] for cls in self.clss)\n update_required = frame_number % self.update_every == 0 or self.last_plot_im is None\n if update_required:\n self.last_plot_im = self.update_graph(\n frame_number=frame_number, count_dict=self.clswise_count, plot=self.type\n )\n plot_im = self.last_plot_im\n else:\n raise ModuleNotFoundError(f\"{self.type} chart is not supported ❌\")\n\n # return output dictionary with summary for more usage\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), classwise_count=self.clswise_count)\n\n def update_graph(\n self, frame_number: int, count_dict: Optional[Dict[str, int]] = None, plot: str = \"line\"\n ) -> np.ndarray:\n \"\"\"\n Update the graph with new data for single or multiple classes.\n\n Args:\n frame_number (int): The current frame number.\n count_dict (Dict[str, int], optional): Dictionary with class names as keys and counts as values for\n multiple classes. If None, updates a single line graph.\n plot (str): Type of the plot. 
Options are 'line', 'bar', 'pie', or 'area'.\n\n Returns:\n (np.ndarray): Updated image containing the graph.\n\n Examples:\n >>> analytics = Analytics(analytics_type=\"bar\")\n >>> frame_num = 10\n >>> results_dict = {\"person\": 5, \"car\": 3}\n >>> updated_image = analytics.update_graph(frame_num, results_dict, plot=\"bar\")\n \"\"\"\n if count_dict is None:\n # Single line update\n x_data = np.append(self.line.get_xdata(), float(frame_number))\n y_data = np.append(self.line.get_ydata(), float(self.total_counts))\n\n if len(x_data) > self.max_points:\n x_data, y_data = x_data[-self.max_points :], y_data[-self.max_points :]\n\n self.line.set_data(x_data, y_data)\n self.line.set_label(\"Counts\")\n self.line.set_color(\"#7b0068\") # Pink color\n self.line.set_marker(\"*\")\n self.line.set_markersize(self.line_width * 5)\n else:\n labels = list(count_dict.keys())\n counts = list(count_dict.values())\n if plot == \"area\":\n color_cycle = cycle([\"#DD00BA\", \"#042AFF\", \"#FF4447\", \"#7D24FF\", \"#BD00FF\"])\n # Multiple lines or area update\n x_data = self.ax.lines[0].get_xdata() if self.ax.lines else np.array([])\n y_data_dict = {key: np.array([]) for key in count_dict.keys()}\n if self.ax.lines:\n for line, key in zip(self.ax.lines, count_dict.keys()):\n y_data_dict[key] = line.get_ydata()\n\n x_data = np.append(x_data, float(frame_number))\n max_length = len(x_data)\n for key in count_dict.keys():\n y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key]))\n if len(y_data_dict[key]) < max_length:\n y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])))\n if len(x_data) > self.max_points:\n x_data = x_data[1:]\n for key in count_dict.keys():\n y_data_dict[key] = y_data_dict[key][1:]\n\n self.ax.clear()\n for key, y_data in y_data_dict.items():\n color = next(color_cycle)\n self.ax.fill_between(x_data, y_data, color=color, alpha=0.55)\n self.ax.plot(\n x_data,\n y_data,\n color=color,\n linewidth=self.line_width,\n marker=\"o\",\n markersize=self.line_width * 5,\n label=f\"{key} Data Points\",\n )\n if plot == \"bar\":\n self.ax.clear() # clear bar data\n for label in labels: # Map labels to colors\n if label not in self.color_mapping:\n self.color_mapping[label] = next(self.color_cycle)\n colors = [self.color_mapping[label] for label in labels]\n bars = self.ax.bar(labels, counts, color=colors)\n for bar, count in zip(bars, counts):\n self.ax.text(\n bar.get_x() + bar.get_width() / 2,\n bar.get_height(),\n str(count),\n ha=\"center\",\n va=\"bottom\",\n color=self.fg_color,\n )\n # Create the legend using labels from the bars\n for bar, label in zip(bars, labels):\n bar.set_label(label) # Assign label to each bar\n self.ax.legend(loc=\"upper left\", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color)\n if plot == \"pie\":\n total = sum(counts)\n percentages = [size / total * 100 for size in counts]\n start_angle = 90\n self.ax.clear()\n\n # Create pie chart and create legend labels with percentages\n wedges, _ = self.ax.pie(\n counts, labels=labels, startangle=start_angle, textprops={\"color\": self.fg_color}, autopct=None\n )\n legend_labels = [f\"{label} ({percentage:.1f}%)\" for label, percentage in zip(labels, percentages)]\n\n # Assign the legend using the wedges and manually created labels\n self.ax.legend(wedges, legend_labels, title=\"Classes\", loc=\"center left\", bbox_to_anchor=(1, 0, 0.5, 1))\n self.fig.subplots_adjust(left=0.1, right=0.75) # Adjust layout to fit the legend\n\n # Common plot settings\n 
self.ax.set_facecolor(\"#f0f0f0\") # Set to light gray or any other color you like\n self.ax.grid(True, linestyle=\"--\", linewidth=0.5, alpha=0.5) # Display grid for more data insights\n self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)\n self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)\n self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)\n\n # Add and format legend\n legend = self.ax.legend(loc=\"upper left\", fontsize=13, facecolor=self.bg_color, edgecolor=self.bg_color)\n for text in legend.get_texts():\n text.set_color(self.fg_color)\n\n # Redraw graph, update view, capture, and display the updated plot\n self.ax.relim()\n self.ax.autoscale_view()\n self.canvas.draw()\n im0 = np.array(self.canvas.renderer.buffer_rgba())\n im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)\n self.display_output(im0)\n\n return im0 # Return the image", "chunk_type": "class", "name": "Analytics", "file_path": "ultralytics\\ultralytics\\solutions\\analytics.py", "start_line": 12, "end_line": 263, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "A class for creating and updating various types of charts for visual analytics.\n\nThis class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts\nbased on object detection and tracking data.\n\nAttributes:\n type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').\n x_label (str): Label for the x-axis.\n y_label (str): Label for the y-axis.\n bg_color (str): Background color of the chart frame.\n fg_color (str): Foreground color of the chart frame.\n title (str): Title of the chart window.\n max_points (int): Maximum number of data points to display on the chart.\n fontsize (int): Font size for text display.\n color_cycle (cycle): Cyclic iterator for chart colors.\n total_counts (int): Total count of detected objects (used for line charts).\n clswise_count (Dict[str, int]): Dictionary for class-wise object counts.\n fig (Figure): Matplotlib figure object for the chart.\n ax (Axes): Matplotlib axes object for the chart.\n canvas (FigureCanvasAgg): Canvas for rendering the chart.\n lines (dict): Dictionary to store line objects for area charts.\n color_mapping (Dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.\n\nMethods:\n process: Process image data and update the chart.\n update_graph: Update the chart with new data points.\n\nExamples:\n >>> analytics = Analytics(analytics_type=\"line\")\n >>> frame = cv2.imread(\"image.jpg\")\n >>> results = analytics.process(frame, frame_number=1)\n >>> cv2.imshow(\"Analytics\", results.plot_im)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools.cycle", "typing.Any", "typing.Dict", "typing.Optional", "cv2", "numpy", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionResults", "matplotlib.pyplot", "matplotlib.backends.backend_agg.FigureCanvasAgg", "matplotlib.figure.Figure", "collections.Counter", "BaseSolution" ], "chunk_id": "class_Analytics_9ab5fae7" }, { "content": "from dataclasses import dataclass, field", "chunk_type": "import", "name": "dataclass, field", "file_path": "ultralytics\\ultralytics\\solutions\\config.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
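A minimal sketch of driving the `Analytics` chunk above frame by frame, per its docstring example. The video path and loop are illustrative assumptions; note the class only re-renders the matplotlib figure every `update_every` frames (30 by default) and otherwise reuses the cached chart.

```python
# Sketch: running line chart of per-frame detection counts with Analytics.
import cv2

from ultralytics.solutions.analytics import Analytics

analytics = Analytics(analytics_type="line", show=False)
cap = cv2.VideoCapture("traffic.mp4")  # hypothetical input video

frame_number = 0
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame_number += 1
    results = analytics.process(frame, frame_number=frame_number)
    chart = results.plot_im  # rendered chart as a BGR image

cap.release()
```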
"dependencies": null, "chunk_id": "import_dataclass, field_fed1434c" }, { "content": "from typing import Any, List, Optional, Tuple", "chunk_type": "import", "name": "Any, List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\solutions\\config.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Optional, Tuple_6ab5d836" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\config.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_2054aad3" }, { "content": "class SolutionConfig:\n \"\"\"\n Manages configuration parameters for Ultralytics Vision AI solutions.\n\n The SolutionConfig class serves as a centralized configuration container for all the\n Ultralytics solution modules: https://docs.ultralytics.com/solutions/#solutions.\n It leverages Python `dataclass` for clear, type-safe, and maintainable parameter definitions.\n\n Attributes:\n source (str, optional): Path to the input source (video, RTSP, etc.). Only usable with Solutions CLI.\n model (str, optional): Path to the Ultralytics YOLO model to be used for inference.\n classes (List[int], optional): List of class indices to filter detections.\n show_conf (bool): Whether to show confidence scores on the visual output.\n show_labels (bool): Whether to display class labels on visual output.\n region (List[Tuple[int, int]], optional): Polygonal region or line for object counting.\n colormap (int, optional): OpenCV colormap constant for visual overlays (e.g., cv2.COLORMAP_JET).\n show_in (bool): Whether to display count number for objects entering the region.\n show_out (bool): Whether to display count number for objects leaving the region.\n up_angle (float): Upper angle threshold used in pose-based workouts monitoring.\n down_angle (int): Lower angle threshold used in pose-based workouts monitoring.\n kpts (List[int]): Keypoint indices to monitor, e.g., for pose analytics.\n analytics_type (str): Type of analytics to perform (\"line\", \"area\", \"bar\", \"pie\", etc.).\n figsize (Tuple[int, int], optional): Size of the matplotlib figure used for analytical plots (width, height).\n blur_ratio (float): Ratio used to blur objects in the video frames (0.0 to 1.0).\n vision_point (Tuple[int, int]): Reference point for directional tracking or perspective drawing.\n crop_dir (str): Directory path to save cropped detection images.\n json_file (str): Path to a JSON file containing data for parking areas.\n line_width (int): Width for visual display i.e. 
bounding boxes, keypoints, counts.\n records (int): Number of detection records to send email alerts.\n fps (float): Frame rate (Frames Per Second) for speed estimation calculation.\n max_hist (int): Maximum number of historical points or states stored per tracked object for speed estimation.\n meter_per_pixel (float): Scale for real-world measurement, used in speed or distance calculations.\n max_speed (int): Maximum speed limit (e.g., km/h or mph) used in visual alerts or constraints.\n show (bool): Whether to display the visual output on screen.\n iou (float): Intersection-over-Union threshold for detection filtering.\n conf (float): Confidence threshold for keeping predictions.\n device (str, optional): Device to run inference on (e.g., 'cpu', '0' for CUDA GPU).\n max_det (int): Maximum number of detections allowed per video frame.\n half (bool): Whether to use FP16 precision (requires a supported CUDA device).\n tracker (str): Path to tracking configuration YAML file (e.g., 'botsort.yaml').\n verbose (bool): Enable verbose logging output for debugging or diagnostics.\n data (str): Path to image directory used for similarity search.\n\n Methods:\n update: Update the configuration with user-defined keyword arguments and raise error on invalid keys.\n\n Examples:\n >>> from ultralytics.solutions.config import SolutionConfig\n >>> cfg = SolutionConfig(model=\"yolo11n.pt\", region=[(0, 0), (100, 0), (100, 100), (0, 100)])\n >>> cfg.update(show=False, conf=0.3)\n >>> print(cfg.model)\n \"\"\"\n\n source: Optional[str] = None\n model: Optional[str] = None\n classes: Optional[List[int]] = None\n show_conf: bool = True\n show_labels: bool = True\n region: Optional[List[Tuple[int, int]]] = None\n colormap: Optional[int] = cv2.COLORMAP_DEEPGREEN\n show_in: bool = True\n show_out: bool = True\n up_angle: float = 145.0\n down_angle: int = 90\n kpts: List[int] = field(default_factory=lambda: [6, 8, 10])\n analytics_type: str = \"line\"\n figsize: Optional[Tuple[int, int]] = (12.8, 7.2)\n blur_ratio: float = 0.5\n vision_point: Tuple[int, int] = (20, 20)\n crop_dir: str = \"cropped-detections\"\n json_file: str = None\n line_width: int = 2\n records: int = 5\n fps: float = 30.0\n max_hist: int = 5\n meter_per_pixel: float = 0.05\n max_speed: int = 120\n show: bool = False\n iou: float = 0.7\n conf: float = 0.25\n device: Optional[str] = None\n max_det: int = 300\n half: bool = False\n tracker: str = \"botsort.yaml\"\n verbose: bool = True\n data: str = \"images\"\n\n def update(self, **kwargs: Any):\n \"\"\"Update configuration parameters with new values provided as keyword arguments.\"\"\"\n for key, value in kwargs.items():\n if hasattr(self, key):\n setattr(self, key, value)\n else:\n url = \"https://docs.ultralytics.com/solutions/#solutions-arguments\"\n raise ValueError(f\"{key} is not a valid solution argument, see {url}\")\n\n return self", "chunk_type": "class", "name": "SolutionConfig", "file_path": "ultralytics\\ultralytics\\solutions\\config.py", "start_line": 10, "end_line": 106, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Manages configuration parameters for Ultralytics Vision AI solutions.\n\nThe SolutionConfig class serves as a centralized configuration container for all the\nUltralytics solution modules: https://docs.ultralytics.com/solutions/#solutions.\nIt leverages Python `dataclass` for clear, type-safe, and maintainable parameter definitions.\n\nAttributes:\n source (str, optional): Path to the input source (video, RTSP, etc.). 
Only usable with Solutions CLI.\n model (str, optional): Path to the Ultralytics YOLO model to be used for inference.\n classes (List[int], optional): List of class indices to filter detections.\n show_conf (bool): Whether to show confidence scores on the visual output.\n show_labels (bool): Whether to display class labels on visual output.\n region (List[Tuple[int, int]], optional): Polygonal region or line for object counting.\n colormap (int, optional): OpenCV colormap constant for visual overlays (e.g., cv2.COLORMAP_JET).\n show_in (bool): Whether to display count number for objects entering the region.\n show_out (bool): Whether to display count number for objects leaving the region.\n up_angle (float): Upper angle threshold used in pose-based workouts monitoring.\n down_angle (int): Lower angle threshold used in pose-based workouts monitoring.\n kpts (List[int]): Keypoint indices to monitor, e.g., for pose analytics.\n analytics_type (str): Type of analytics to perform (\"line\", \"area\", \"bar\", \"pie\", etc.).\n figsize (Tuple[int, int], optional): Size of the matplotlib figure used for analytical plots (width, height).\n blur_ratio (float): Ratio used to blur objects in the video frames (0.0 to 1.0).\n vision_point (Tuple[int, int]): Reference point for directional tracking or perspective drawing.\n crop_dir (str): Directory path to save cropped detection images.\n json_file (str): Path to a JSON file containing data for parking areas.\n line_width (int): Width for visual display i.e. bounding boxes, keypoints, counts.\n records (int): Number of detection records to send email alerts.\n fps (float): Frame rate (Frames Per Second) for speed estimation calculation.\n max_hist (int): Maximum number of historical points or states stored per tracked object for speed estimation.\n meter_per_pixel (float): Scale for real-world measurement, used in speed or distance calculations.\n max_speed (int): Maximum speed limit (e.g., km/h or mph) used in visual alerts or constraints.\n show (bool): Whether to display the visual output on screen.\n iou (float): Intersection-over-Union threshold for detection filtering.\n conf (float): Confidence threshold for keeping predictions.\n device (str, optional): Device to run inference on (e.g., 'cpu', '0' for CUDA GPU).\n max_det (int): Maximum number of detections allowed per video frame.\n half (bool): Whether to use FP16 precision (requires a supported CUDA device).\n tracker (str): Path to tracking configuration YAML file (e.g., 'botsort.yaml').\n verbose (bool): Enable verbose logging output for debugging or diagnostics.\n data (str): Path to image directory used for similarity search.\n\nMethods:\n update: Update the configuration with user-defined keyword arguments and raise error on invalid keys.\n\nExamples:\n >>> from ultralytics.solutions.config import SolutionConfig\n >>> cfg = SolutionConfig(model=\"yolo11n.pt\", region=[(0, 0), (100, 0), (100, 100), (0, 100)])\n >>> cfg.update(show=False, conf=0.3)\n >>> print(cfg.model)", "parameters": null, "return_type": null, "decorators": [ "dataclass" ], "complexity_score": null, "dependencies": [ "dataclasses.dataclass", "dataclasses.field", "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "cv2" ], "chunk_id": "class_SolutionConfig_738541d8" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\solutions\\distance_calculation.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": 
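A short sketch exercising the `SolutionConfig` chunk above, taken almost verbatim from its docstring example: valid keys are updated in place, while unknown keys raise `ValueError` pointing at the solutions-arguments documentation.

```python
# Sketch: create, update, and (deliberately) mis-update a SolutionConfig.
from ultralytics.solutions.config import SolutionConfig

cfg = SolutionConfig(model="yolo11n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
cfg.update(conf=0.3, show=False)      # valid keys are applied
print(cfg.model, cfg.conf, cfg.kpts)  # untouched fields keep their defaults

try:
    cfg.update(not_a_key=1)           # invalid key
except ValueError as e:
    print(e)                          # error message links to the solutions arguments docs
```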
null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_f0eb5c0c" }, { "content": "from typing import Any, Dict, List", "chunk_type": "import", "name": "Any, Dict, List", "file_path": "ultralytics\\ultralytics\\solutions\\distance_calculation.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List_d50efa15" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\distance_calculation.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_a747dc7a" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\distance_calculation.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_8dc910d2" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\distance_calculation.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_5744a72c" }, { "content": "class DistanceCalculation(BaseSolution):\n \"\"\"\n A class to calculate distance between two objects in a real-time video stream based on their tracks.\n\n This class extends BaseSolution to provide functionality for selecting objects and calculating the distance\n between them in a video stream using YOLO object detection and tracking.\n\n Attributes:\n left_mouse_count (int): Counter for left mouse button clicks.\n selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.\n centroids (List[List[int]]): List to store centroids of selected bounding boxes.\n\n Methods:\n mouse_event_for_distance: Handle mouse events for selecting objects in the video stream.\n process: Process video frames and calculate the distance between selected objects.\n\n Examples:\n >>> distance_calc = DistanceCalculation()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = distance_calc.process(frame)\n >>> cv2.imshow(\"Distance Calculation\", results.plot_im)\n >>> cv2.waitKey(0)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the DistanceCalculation class for measuring object distances in video streams.\"\"\"\n super().__init__(**kwargs)\n\n # Mouse event information\n self.left_mouse_count = 0\n self.selected_boxes: Dict[int, List[float]] = {}\n self.centroids: List[List[int]] = [] # Store centroids of selected objects\n\n def mouse_event_for_distance(self, event: int, x: int, y: int, flags: int, param: Any) -> None:\n \"\"\"\n Handle mouse events to select regions in a 
real-time video stream for distance calculation.\n\n Args:\n event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).\n x (int): X-coordinate of the mouse pointer.\n y (int): Y-coordinate of the mouse pointer.\n flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).\n param (Any): Additional parameters passed to the function.\n\n Examples:\n >>> # Assuming 'dc' is an instance of DistanceCalculation\n >>> cv2.setMouseCallback(\"window_name\", dc.mouse_event_for_distance)\n \"\"\"\n if event == cv2.EVENT_LBUTTONDOWN:\n self.left_mouse_count += 1\n if self.left_mouse_count <= 2:\n for box, track_id in zip(self.boxes, self.track_ids):\n if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:\n self.selected_boxes[track_id] = box\n\n elif event == cv2.EVENT_RBUTTONDOWN:\n self.selected_boxes = {}\n self.left_mouse_count = 0\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Process a video frame and calculate the distance between two selected bounding boxes.\n\n This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance\n between two user-selected objects if they have been chosen.\n\n Args:\n im0 (np.ndarray): The input image frame to process.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number\n of tracked objects, and `pixels_distance` (float) representing the distance between selected objects\n in pixels.\n\n Examples:\n >>> import numpy as np\n >>> from ultralytics.solutions import DistanceCalculation\n >>> dc = DistanceCalculation()\n >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> results = dc.process(frame)\n >>> print(f\"Distance: {results.pixels_distance:.2f} pixels\")\n \"\"\"\n self.extract_tracks(im0) # Extract tracks\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n pixels_distance = 0\n # Iterate over bounding boxes, track ids and classes index\n for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):\n annotator.box_label(box, color=colors(int(cls), True), label=self.adjust_box_label(cls, conf, track_id))\n\n # Update selected boxes if they're being tracked\n if len(self.selected_boxes) == 2:\n for trk_id in self.selected_boxes.keys():\n if trk_id == track_id:\n self.selected_boxes[track_id] = box\n\n if len(self.selected_boxes) == 2:\n # Calculate centroids of selected boxes\n self.centroids.extend(\n [[int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)] for box in self.selected_boxes.values()]\n )\n # Calculate Euclidean distance between centroids\n pixels_distance = math.sqrt(\n (self.centroids[0][0] - self.centroids[1][0]) ** 2 + (self.centroids[0][1] - self.centroids[1][1]) ** 2\n )\n annotator.plot_distance_and_line(pixels_distance, self.centroids)\n\n self.centroids = [] # Reset centroids for next frame\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n if self.CFG.get(\"show\") and self.env_check:\n cv2.setMouseCallback(\"Ultralytics Solutions\", self.mouse_event_for_distance)\n\n # Return SolutionResults with processed image and calculated metrics\n return SolutionResults(plot_im=plot_im, pixels_distance=pixels_distance, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "DistanceCalculation", "file_path": "ultralytics\\ultralytics\\solutions\\distance_calculation.py", 
"start_line": 12, "end_line": 126, "start_col": 0, "end_col": 114, "parent_name": null, "docstring": "A class to calculate distance between two objects in a real-time video stream based on their tracks.\n\nThis class extends BaseSolution to provide functionality for selecting objects and calculating the distance\nbetween them in a video stream using YOLO object detection and tracking.\n\nAttributes:\n left_mouse_count (int): Counter for left mouse button clicks.\n selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.\n centroids (List[List[int]]): List to store centroids of selected bounding boxes.\n\nMethods:\n mouse_event_for_distance: Handle mouse events for selecting objects in the video stream.\n process: Process video frames and calculate the distance between selected objects.\n\nExamples:\n >>> distance_calc = DistanceCalculation()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = distance_calc.process(frame)\n >>> cv2.imshow(\"Distance Calculation\", results.plot_im)\n >>> cv2.waitKey(0)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.Any", "typing.Dict", "typing.List", "cv2", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_DistanceCalculation_e41e10e7" }, { "content": "from typing import Any, List", "chunk_type": "import", "name": "Any, List", "file_path": "ultralytics\\ultralytics\\solutions\\heatmap.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List_a2cd2ab8" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\heatmap.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_918b2bbe" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\heatmap.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_a22c1c8c" }, { "content": "from ultralytics.solutions.object_counter import ObjectCounter", "chunk_type": "import", "name": "ObjectCounter", "file_path": "ultralytics\\ultralytics\\solutions\\heatmap.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ObjectCounter_d18c36b5" }, { "content": "from ultralytics.solutions.solutions import SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\heatmap.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
"dependencies": null, "chunk_id": "import_SolutionAnnotator, SolutionResults_8cbf8f72" }, { "content": "class Heatmap(ObjectCounter):\n \"\"\"\n A class to draw heatmaps in real-time video streams based on object tracks.\n\n This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video\n streams. It uses tracked object positions to create a cumulative heatmap effect over time.\n\n Attributes:\n initialized (bool): Flag indicating whether the heatmap has been initialized.\n colormap (int): OpenCV colormap used for heatmap visualization.\n heatmap (np.ndarray): Array storing the cumulative heatmap data.\n annotator (SolutionAnnotator): Object for drawing annotations on the image.\n\n Methods:\n heatmap_effect: Calculate and update the heatmap effect for a given bounding box.\n process: Generate and apply the heatmap effect to each frame.\n\n Examples:\n >>> from ultralytics.solutions import Heatmap\n >>> heatmap = Heatmap(model=\"yolo11n.pt\", colormap=cv2.COLORMAP_JET)\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> processed_frame = heatmap.process(frame)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.\n\n Args:\n **kwargs (Any): Keyword arguments passed to the parent ObjectCounter class.\n \"\"\"\n super().__init__(**kwargs)\n\n self.initialized = False # Flag for heatmap initialization\n if self.region is not None: # Check if user provided the region coordinates\n self.initialize_region()\n\n # Store colormap\n self.colormap = self.CFG[\"colormap\"]\n self.heatmap = None\n\n def heatmap_effect(self, box: List[float]) -> None:\n \"\"\"\n Efficiently calculate heatmap area and effect location for applying colormap.\n\n Args:\n box (List[float]): Bounding box coordinates [x0, y0, x1, y1].\n \"\"\"\n x0, y0, x1, y1 = map(int, box)\n radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2\n\n # Create a meshgrid with region of interest (ROI) for vectorized distance calculations\n xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))\n\n # Calculate squared distances from the center\n dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2\n\n # Create a mask of points within the radius\n within_radius = dist_squared <= radius_squared\n\n # Update only the values within the bounding box in a single vectorized operation\n self.heatmap[y0:y1, x0:x1][within_radius] += 2\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Generate heatmap for each frame using Ultralytics tracking.\n\n Args:\n im0 (np.ndarray): Input image array for processing.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`,\n 'in_count' (int, count of objects entering the region),\n 'out_count' (int, count of objects exiting the region),\n 'classwise_count' (dict, per-class object count), and\n 'total_tracks' (int, total number of tracked objects).\n \"\"\"\n if not self.initialized:\n self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99\n self.initialized = True # Initialize heatmap only once\n\n self.extract_tracks(im0) # Extract tracks\n self.annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n # Iterate over bounding boxes, track ids and classes index\n for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):\n # Apply heatmap effect for the bounding box\n self.heatmap_effect(box)\n\n if self.region is not None:\n 
self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)\n self.store_tracking_history(track_id, box) # Store track history\n # Get previous position if available\n prev_position = None\n if len(self.track_history[track_id]) > 1:\n prev_position = self.track_history[track_id][-2]\n self.count_objects(self.track_history[track_id][-1], track_id, prev_position, cls) # object counting\n\n plot_im = self.annotator.result()\n if self.region is not None:\n self.display_counts(plot_im) # Display the counts on the frame\n\n # Normalize, apply colormap to heatmap and combine with original image\n if self.track_data.is_track:\n normalized_heatmap = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)\n colored_heatmap = cv2.applyColorMap(normalized_heatmap, self.colormap)\n plot_im = cv2.addWeighted(plot_im, 0.5, colored_heatmap, 0.5, 0)\n\n self.display_output(plot_im) # Display output with base class function\n\n # Return SolutionResults\n return SolutionResults(\n plot_im=plot_im,\n in_count=self.in_count,\n out_count=self.out_count,\n classwise_count=dict(self.classwise_count),\n total_tracks=len(self.track_ids),\n )", "chunk_type": "class", "name": "Heatmap", "file_path": "ultralytics\\ultralytics\\solutions\\heatmap.py", "start_line": 12, "end_line": 129, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class to draw heatmaps in real-time video streams based on object tracks.\n\nThis class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video\nstreams. It uses tracked object positions to create a cumulative heatmap effect over time.\n\nAttributes:\n initialized (bool): Flag indicating whether the heatmap has been initialized.\n colormap (int): OpenCV colormap used for heatmap visualization.\n heatmap (np.ndarray): Array storing the cumulative heatmap data.\n annotator (SolutionAnnotator): Object for drawing annotations on the image.\n\nMethods:\n heatmap_effect: Calculate and update the heatmap effect for a given bounding box.\n process: Generate and apply the heatmap effect to each frame.\n\nExamples:\n >>> from ultralytics.solutions import Heatmap\n >>> heatmap = Heatmap(model=\"yolo11n.pt\", colormap=cv2.COLORMAP_JET)\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> processed_frame = heatmap.process(frame)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.List", "cv2", "numpy", "ultralytics.solutions.object_counter.ObjectCounter", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ObjectCounter" ], "chunk_id": "class_Heatmap_389b1fa4" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\instance_segmentation.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_b16d738f" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\solutions\\instance_segmentation.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
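A minimal usage sketch for the `Heatmap` chunk above, based on its docstring example. The video path is hypothetical; when a region is supplied, counting results are included alongside the blended heatmap image.

```python
# Sketch: cumulative heatmap over a video stream.
import cv2

from ultralytics.solutions import Heatmap

heatmap = Heatmap(model="yolo11n.pt", colormap=cv2.COLORMAP_JET, show=False)
cap = cv2.VideoCapture("mall.mp4")  # hypothetical input video

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = heatmap.process(frame)  # plot_im blends the colored heatmap at 50% opacity

cap.release()
```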
"import_Results_c46d4b51" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\instance_segmentation.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionResults_e1acea62" }, { "content": "class InstanceSegmentation(BaseSolution):\n \"\"\"\n A class to manage instance segmentation in images or video streams.\n\n This class extends the BaseSolution class and provides functionality for performing instance segmentation, including\n drawing segmented masks with bounding boxes and labels.\n\n Attributes:\n model (str): The segmentation model to use for inference.\n line_width (int): Width of the bounding box and text lines.\n names (Dict[int, str]): Dictionary mapping class indices to class names.\n clss (List[int]): List of detected class indices.\n track_ids (List[int]): List of track IDs for detected instances.\n masks (List[np.ndarray]): List of segmentation masks for detected instances.\n show_conf (bool): Whether to display confidence scores.\n show_labels (bool): Whether to display class labels.\n show_boxes (bool): Whether to display bounding boxes.\n\n Methods:\n process: Process the input image to perform instance segmentation and annotate results.\n extract_tracks: Extract tracks including bounding boxes, classes, and masks from model predictions.\n\n Examples:\n >>> segmenter = InstanceSegmentation()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = segmenter.process(frame)\n >>> print(f\"Total segmented instances: {results.total_tracks}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the InstanceSegmentation class for detecting and annotating segmented instances.\n\n Args:\n **kwargs (Any): Keyword arguments passed to the BaseSolution parent class.\n model (str): Model name or path, defaults to \"yolo11n-seg.pt\".\n \"\"\"\n kwargs[\"model\"] = kwargs.get(\"model\", \"yolo11n-seg.pt\")\n super().__init__(**kwargs)\n\n self.show_conf = self.CFG.get(\"show_conf\", True)\n self.show_labels = self.CFG.get(\"show_labels\", True)\n self.show_boxes = self.CFG.get(\"show_boxes\", True)\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Perform instance segmentation on the input image and annotate the results.\n\n Args:\n im0 (np.ndarray): The input image for segmentation.\n\n Returns:\n (SolutionResults): Object containing the annotated image and total number of tracked instances.\n\n Examples:\n >>> segmenter = InstanceSegmentation()\n >>> frame = cv2.imread(\"image.jpg\")\n >>> summary = segmenter.process(frame)\n >>> print(summary)\n \"\"\"\n self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)\n self.masks = getattr(self.tracks, \"masks\", None)\n\n # Iterate over detected classes, track IDs, and segmentation masks\n if self.masks is None:\n self.LOGGER.warning(\"No masks detected! 
Ensure you're using a supported Ultralytics segmentation model.\")\n plot_im = im0\n else:\n results = Results(im0, path=None, names=self.names, boxes=self.track_data.data, masks=self.masks.data)\n plot_im = results.plot(\n line_width=self.line_width,\n boxes=self.show_boxes,\n conf=self.show_conf,\n labels=self.show_labels,\n color_mode=\"instance\",\n )\n\n self.display_output(plot_im) # Display the annotated output using the base class function\n\n # Return SolutionResults\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "InstanceSegmentation", "file_path": "ultralytics\\ultralytics\\solutions\\instance_segmentation.py", "start_line": 9, "end_line": 89, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "A class to manage instance segmentation in images or video streams.\n\nThis class extends the BaseSolution class and provides functionality for performing instance segmentation, including\ndrawing segmented masks with bounding boxes and labels.\n\nAttributes:\n model (str): The segmentation model to use for inference.\n line_width (int): Width of the bounding box and text lines.\n names (Dict[int, str]): Dictionary mapping class indices to class names.\n clss (List[int]): List of detected class indices.\n track_ids (List[int]): List of track IDs for detected instances.\n masks (List[np.ndarray]): List of segmentation masks for detected instances.\n show_conf (bool): Whether to display confidence scores.\n show_labels (bool): Whether to display class labels.\n show_boxes (bool): Whether to display bounding boxes.\n\nMethods:\n process: Process the input image to perform instance segmentation and annotate results.\n extract_tracks: Extract tracks including bounding boxes, classes, and masks from model predictions.\n\nExamples:\n >>> segmenter = InstanceSegmentation()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = segmenter.process(frame)\n >>> print(f\"Total segmented instances: {results.total_tracks}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "ultralytics.engine.results.Results", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionResults", "BaseSolution" ], "chunk_id": "class_InstanceSegmentation_98483f06" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\object_blurrer.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_c8a8d6c9" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\object_blurrer.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_87fa0a42" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\object_blurrer.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": 
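A minimal sketch for the `InstanceSegmentation` chunk above, following its docstring example. The image path is hypothetical; if the loaded model yields no masks, `process` logs a warning and returns the unannotated frame.

```python
# Sketch: segment a single image and report the number of tracked instances.
import cv2

from ultralytics.solutions.instance_segmentation import InstanceSegmentation

segmenter = InstanceSegmentation(model="yolo11n-seg.pt", show=False)
frame = cv2.imread("frame.jpg")  # hypothetical image
results = segmenter.process(frame)
print(f"Total segmented instances: {results.total_tracks}")
```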
null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_a8aec2d5" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\solutions\\object_blurrer.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_02fa0b4a" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\object_blurrer.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_9bc31cba" }, { "content": "class ObjectBlurrer(BaseSolution):\n \"\"\"\n A class to manage the blurring of detected objects in a real-time video stream.\n\n This class extends the BaseSolution class and provides functionality for blurring objects based on detected bounding\n boxes. The blurred areas are updated directly in the input image, allowing for privacy preservation or other effects.\n\n Attributes:\n blur_ratio (int): The intensity of the blur effect applied to detected objects (higher values create more blur).\n iou (float): Intersection over Union threshold for object detection.\n conf (float): Confidence threshold for object detection.\n\n Methods:\n process: Apply a blurring effect to detected objects in the input image.\n extract_tracks: Extract tracking information from detected objects.\n display_output: Display the processed output image.\n\n Examples:\n >>> blurrer = ObjectBlurrer()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> processed_results = blurrer.process(frame)\n >>> print(f\"Total blurred objects: {processed_results.total_tracks}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the ObjectBlurrer class for applying a blur effect to objects detected in video streams or images.\n\n Args:\n **kwargs (Any): Keyword arguments passed to the parent class and for configuration.\n blur_ratio (float): Intensity of the blur effect (0.1-1.0, default=0.5).\n \"\"\"\n super().__init__(**kwargs)\n blur_ratio = self.CFG[\"blur_ratio\"]\n if blur_ratio < 0.1:\n LOGGER.warning(\"blur ratio cannot be less than 0.1, updating it to default value 0.5\")\n blur_ratio = 0.5\n self.blur_ratio = int(blur_ratio * 100)\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Apply a blurring effect to detected objects in the input image.\n\n This method extracts tracking information, applies blur to regions corresponding to detected objects,\n and annotates the image with bounding boxes.\n\n Args:\n im0 (np.ndarray): The input image containing detected objects.\n\n Returns:\n (SolutionResults): Object containing the processed image and number of tracked objects.\n - plot_im (np.ndarray): The annotated output image with blurred objects.\n - total_tracks (int): The total number of tracked objects in the frame.\n\n Examples:\n >>> blurrer = ObjectBlurrer()\n >>> frame = cv2.imread(\"image.jpg\")\n >>> results = blurrer.process(frame)\n >>> print(f\"Blurred {results.total_tracks} objects\")\n \"\"\"\n self.extract_tracks(im0) # Extract tracks\n annotator = SolutionAnnotator(im0, self.line_width)\n\n # Iterate over 
bounding boxes and classes\n for box, cls, conf in zip(self.boxes, self.clss, self.confs):\n # Crop and blur the detected object\n blur_obj = cv2.blur(\n im0[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])],\n (self.blur_ratio, self.blur_ratio),\n )\n # Update the blurred area in the original image\n im0[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] = blur_obj\n annotator.box_label(\n box, label=self.adjust_box_label(cls, conf), color=colors(cls, True)\n ) # Annotate bounding box\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display the output using the base class function\n\n # Return a SolutionResults\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "ObjectBlurrer", "file_path": "ultralytics\\ultralytics\\solutions\\object_blurrer.py", "start_line": 12, "end_line": 92, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "A class to manage the blurring of detected objects in a real-time video stream.\n\nThis class extends the BaseSolution class and provides functionality for blurring objects based on detected bounding\nboxes. The blurred areas are updated directly in the input image, allowing for privacy preservation or other effects.\n\nAttributes:\n blur_ratio (int): The intensity of the blur effect applied to detected objects (higher values create more blur).\n iou (float): Intersection over Union threshold for object detection.\n conf (float): Confidence threshold for object detection.\n\nMethods:\n process: Apply a blurring effect to detected objects in the input image.\n extract_tracks: Extract tracking information from detected objects.\n display_output: Display the processed output image.\n\nExamples:\n >>> blurrer = ObjectBlurrer()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> processed_results = blurrer.process(frame)\n >>> print(f\"Total blurred objects: {processed_results.total_tracks}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "cv2", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_ObjectBlurrer_084e818d" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\ultralytics\\solutions\\object_counter.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_aa0edce9" }, { "content": "from typing import Any, Optional, Tuple", "chunk_type": "import", "name": "Any, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\solutions\\object_counter.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Optional, Tuple_499ca0de" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\object_counter.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 92, 
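A minimal usage sketch for the ObjectBlurrer chunk above. The model name, video path, and the keyword-argument style (solution kwargs merged into self.CFG) are assumptions for illustration, not taken verbatim from this record.

import cv2
from ultralytics.solutions import ObjectBlurrer  # assumed import path

blurrer = ObjectBlurrer(model="yolo11n.pt", blur_ratio=0.5, show=False)  # blur_ratio < 0.1 falls back to 0.5
cap = cv2.VideoCapture("path/to/video.mp4")  # hypothetical input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = blurrer.process(frame)  # SolutionResults with plot_im and total_tracks
    print(f"Blurred {results.total_tracks} tracked objects in this frame")
cap.release()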
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_3a6009b8" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\object_counter.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_fa7a8efa" }, { "content": "class ObjectCounter(BaseSolution):\n \"\"\"\n A class to manage the counting of objects in a real-time video stream based on their tracks.\n\n This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a\n specified region in a video stream. It supports both polygonal and linear regions for counting.\n\n Attributes:\n in_count (int): Counter for objects moving inward.\n out_count (int): Counter for objects moving outward.\n counted_ids (List[int]): List of IDs of objects that have been counted.\n classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class.\n region_initialized (bool): Flag indicating whether the counting region has been initialized.\n show_in (bool): Flag to control display of inward count.\n show_out (bool): Flag to control display of outward count.\n margin (int): Margin for background rectangle size to display counts properly.\n\n Methods:\n count_objects: Count objects within a polygonal or linear region based on their tracks.\n display_counts: Display object counts on the frame.\n process: Process input data and update counts.\n\n Examples:\n >>> counter = ObjectCounter()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = counter.process(frame)\n >>> print(f\"Inward count: {counter.in_count}, Outward count: {counter.out_count}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the ObjectCounter class for real-time object counting in video streams.\"\"\"\n super().__init__(**kwargs)\n\n self.in_count = 0 # Counter for objects moving inward\n self.out_count = 0 # Counter for objects moving outward\n self.counted_ids = [] # List of IDs of objects that have been counted\n self.classwise_count = defaultdict(lambda: {\"IN\": 0, \"OUT\": 0}) # Dictionary for counts, categorized by class\n self.region_initialized = False # Flag indicating whether the region has been initialized\n\n self.show_in = self.CFG[\"show_in\"]\n self.show_out = self.CFG[\"show_out\"]\n self.margin = self.line_width * 2 # Scales the background rectangle size to display counts properly\n\n def count_objects(\n self,\n current_centroid: Tuple[float, float],\n track_id: int,\n prev_position: Optional[Tuple[float, float]],\n cls: int,\n ) -> None:\n \"\"\"\n Count objects within a polygonal or linear region based on their tracks.\n\n Args:\n current_centroid (Tuple[float, float]): Current centroid coordinates (x, y) in the current frame.\n track_id (int): Unique identifier for the tracked object.\n prev_position (Tuple[float, float], optional): Last frame position coordinates (x, y) of the track.\n cls (int): Class index for classwise count updates.\n\n Examples:\n >>> counter = ObjectCounter()\n >>> track_line = {1: [100, 200], 2: [110, 210], 3: [120, 220]}\n >>> box = [130, 230, 150, 250]\n >>> track_id_num = 1\n >>> 
previous_position = (120, 220)\n >>> class_to_count = 0 # In COCO model, class 0 = person\n >>> counter.count_objects((140, 240), track_id_num, previous_position, class_to_count)\n \"\"\"\n if prev_position is None or track_id in self.counted_ids:\n return\n\n if len(self.region) == 2: # Linear region (defined as a line segment)\n if self.r_s.intersects(self.LineString([prev_position, current_centroid])):\n # Determine orientation of the region (vertical or horizontal)\n if abs(self.region[0][0] - self.region[1][0]) < abs(self.region[0][1] - self.region[1][1]):\n # Vertical region: Compare x-coordinates to determine direction\n if current_centroid[0] > prev_position[0]: # Moving right\n self.in_count += 1\n self.classwise_count[self.names[cls]][\"IN\"] += 1\n else: # Moving left\n self.out_count += 1\n self.classwise_count[self.names[cls]][\"OUT\"] += 1\n # Horizontal region: Compare y-coordinates to determine direction\n elif current_centroid[1] > prev_position[1]: # Moving downward\n self.in_count += 1\n self.classwise_count[self.names[cls]][\"IN\"] += 1\n else: # Moving upward\n self.out_count += 1\n self.classwise_count[self.names[cls]][\"OUT\"] += 1\n self.counted_ids.append(track_id)\n\n elif len(self.region) > 2: # Polygonal region\n if self.r_s.contains(self.Point(current_centroid)):\n # Determine motion direction for vertical or horizontal polygons\n region_width = max(p[0] for p in self.region) - min(p[0] for p in self.region)\n region_height = max(p[1] for p in self.region) - min(p[1] for p in self.region)\n\n if (\n region_width < region_height\n and current_centroid[0] > prev_position[0]\n or region_width >= region_height\n and current_centroid[1] > prev_position[1]\n ): # Moving right or downward\n self.in_count += 1\n self.classwise_count[self.names[cls]][\"IN\"] += 1\n else: # Moving left or upward\n self.out_count += 1\n self.classwise_count[self.names[cls]][\"OUT\"] += 1\n self.counted_ids.append(track_id)\n\n def display_counts(self, plot_im) -> None:\n \"\"\"\n Display object counts on the input image or frame.\n\n Args:\n plot_im (np.ndarray): The image or frame to display counts on.\n\n Examples:\n >>> counter = ObjectCounter()\n >>> frame = cv2.imread(\"image.jpg\")\n >>> counter.display_counts(frame)\n \"\"\"\n labels_dict = {\n str.capitalize(key): f\"{'IN ' + str(value['IN']) if self.show_in else ''} \"\n f\"{'OUT ' + str(value['OUT']) if self.show_out else ''}\".strip()\n for key, value in self.classwise_count.items()\n if value[\"IN\"] != 0 or value[\"OUT\"] != 0 and (self.show_in or self.show_out)\n }\n if labels_dict:\n self.annotator.display_analytics(plot_im, labels_dict, (104, 31, 17), (255, 255, 255), self.margin)\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Process input data (frames or object tracks) and update object counts.\n\n This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates\n object counts, and displays the results on the input image.\n\n Args:\n im0 (np.ndarray): The input image or frame to be processed.\n\n Returns:\n (SolutionResults): Contains processed image `im0`, 'in_count' (int, count of objects entering the region),\n 'out_count' (int, count of objects exiting the region), 'classwise_count' (dict, per-class object count),\n and 'total_tracks' (int, total number of tracked objects).\n\n Examples:\n >>> counter = ObjectCounter()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> results = counter.process(frame)\n \"\"\"\n if not self.region_initialized:\n 
self.initialize_region()\n self.region_initialized = True\n\n self.extract_tracks(im0) # Extract tracks\n self.annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n self.annotator.draw_region(\n reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2\n ) # Draw region\n\n # Iterate over bounding boxes, track ids and classes index\n for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):\n # Draw bounding box and counting region\n self.annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(cls, True))\n self.store_tracking_history(track_id, box) # Store track history\n\n # Store previous position of track for object counting\n prev_position = None\n if len(self.track_history[track_id]) > 1:\n prev_position = self.track_history[track_id][-2]\n self.count_objects(self.track_history[track_id][-1], track_id, prev_position, cls) # object counting\n\n plot_im = self.annotator.result()\n self.display_counts(plot_im) # Display the counts on the frame\n self.display_output(plot_im) # Display output with base class function\n\n # Return SolutionResults\n return SolutionResults(\n plot_im=plot_im,\n in_count=self.in_count,\n out_count=self.out_count,\n classwise_count=dict(self.classwise_count),\n total_tracks=len(self.track_ids),\n )", "chunk_type": "class", "name": "ObjectCounter", "file_path": "ultralytics\\ultralytics\\solutions\\object_counter.py", "start_line": 10, "end_line": 195, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class to manage the counting of objects in a real-time video stream based on their tracks.\n\nThis class extends the BaseSolution class and provides functionality for counting objects moving in and out of a\nspecified region in a video stream. 
It supports both polygonal and linear regions for counting.\n\nAttributes:\n in_count (int): Counter for objects moving inward.\n out_count (int): Counter for objects moving outward.\n counted_ids (List[int]): List of IDs of objects that have been counted.\n classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class.\n region_initialized (bool): Flag indicating whether the counting region has been initialized.\n show_in (bool): Flag to control display of inward count.\n show_out (bool): Flag to control display of outward count.\n margin (int): Margin for background rectangle size to display counts properly.\n\nMethods:\n count_objects: Count objects within a polygonal or linear region based on their tracks.\n display_counts: Display object counts on the frame.\n process: Process input data and update counts.\n\nExamples:\n >>> counter = ObjectCounter()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = counter.process(frame)\n >>> print(f\"Inward count: {counter.in_count}, Outward count: {counter.out_count}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.defaultdict", "typing.Any", "typing.Optional", "typing.Tuple", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_ObjectCounter_d847be6c" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\solutions\\object_cropper.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_b1a33b92" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\solutions\\object_cropper.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_51f22d2e" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\object_cropper.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_f5753439" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\object_cropper.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionResults_f8c45637" }, { "content": "from ultralytics.utils.plotting import save_one_box", "chunk_type": "import", "name": "save_one_box", "file_path": "ultralytics\\ultralytics\\solutions\\object_cropper.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
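A minimal counting loop for the ObjectCounter chunk above. The two-point line region, model name, and video path are illustrative assumptions; per count_objects, a region with more than two points is treated as a polygon instead of a line.

import cv2
from ultralytics.solutions import ObjectCounter  # assumed import path

counter = ObjectCounter(model="yolo11n.pt", region=[(20, 400), (1260, 400)], show_in=True, show_out=True)
cap = cv2.VideoCapture("path/to/video.mp4")  # hypothetical input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = counter.process(frame)  # updates in_count, out_count, classwise_count
    print(results.in_count, results.out_count, results.classwise_count)
cap.release()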
"complexity_score": null, "dependencies": null, "chunk_id": "import_save_one_box_2b8e0ded" }, { "content": "class ObjectCropper(BaseSolution):\n \"\"\"\n A class to manage the cropping of detected objects in a real-time video stream or images.\n\n This class extends the BaseSolution class and provides functionality for cropping objects based on detected bounding\n boxes. The cropped images are saved to a specified directory for further analysis or usage.\n\n Attributes:\n crop_dir (str): Directory where cropped object images are stored.\n crop_idx (int): Counter for the total number of cropped objects.\n iou (float): IoU (Intersection over Union) threshold for non-maximum suppression.\n conf (float): Confidence threshold for filtering detections.\n\n Methods:\n process: Crop detected objects from the input image and save them to the output directory.\n\n Examples:\n >>> cropper = ObjectCropper()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> processed_results = cropper.process(frame)\n >>> print(f\"Total cropped objects: {cropper.crop_idx}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the ObjectCropper class for cropping objects from detected bounding boxes.\n\n Args:\n **kwargs (Any): Keyword arguments passed to the parent class and used for configuration.\n crop_dir (str): Path to the directory for saving cropped object images.\n \"\"\"\n super().__init__(**kwargs)\n\n self.crop_dir = self.CFG[\"crop_dir\"] # Directory for storing cropped detections\n if not os.path.exists(self.crop_dir):\n os.mkdir(self.crop_dir) # Create directory if it does not exist\n if self.CFG[\"show\"]:\n self.LOGGER.warning(\n f\"show=True disabled for crop solution, results will be saved in the directory named: {self.crop_dir}\"\n )\n self.crop_idx = 0 # Initialize counter for total cropped objects\n self.iou = self.CFG[\"iou\"]\n self.conf = self.CFG[\"conf\"]\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Crop detected objects from the input image and save them as separate images.\n\n Args:\n im0 (np.ndarray): The input image containing detected objects.\n\n Returns:\n (SolutionResults): A SolutionResults object containing the total number of cropped objects and processed\n image.\n\n Examples:\n >>> cropper = ObjectCropper()\n >>> frame = cv2.imread(\"image.jpg\")\n >>> results = cropper.process(frame)\n >>> print(f\"Total cropped objects: {results.total_crop_objects}\")\n \"\"\"\n with self.profilers[0]:\n results = self.model.predict(\n im0,\n classes=self.classes,\n conf=self.conf,\n iou=self.iou,\n device=self.CFG[\"device\"],\n verbose=False,\n )[0]\n\n for box in results.boxes:\n self.crop_idx += 1\n save_one_box(\n box.xyxy,\n im0,\n file=Path(self.crop_dir) / f\"crop_{self.crop_idx}.jpg\",\n BGR=True,\n )\n\n # Return SolutionResults\n return SolutionResults(plot_im=im0, total_crop_objects=self.crop_idx)", "chunk_type": "class", "name": "ObjectCropper", "file_path": "ultralytics\\ultralytics\\solutions\\object_cropper.py", "start_line": 11, "end_line": 92, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": "A class to manage the cropping of detected objects in a real-time video stream or images.\n\nThis class extends the BaseSolution class and provides functionality for cropping objects based on detected bounding\nboxes. 
The cropped images are saved to a specified directory for further analysis or usage.\n\nAttributes:\n crop_dir (str): Directory where cropped object images are stored.\n crop_idx (int): Counter for the total number of cropped objects.\n iou (float): IoU (Intersection over Union) threshold for non-maximum suppression.\n conf (float): Confidence threshold for filtering detections.\n\nMethods:\n process: Crop detected objects from the input image and save them to the output directory.\n\nExamples:\n >>> cropper = ObjectCropper()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> processed_results = cropper.process(frame)\n >>> print(f\"Total cropped objects: {cropper.crop_idx}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "pathlib.Path", "typing.Any", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.save_one_box", "BaseSolution" ], "chunk_id": "class_ObjectCropper_997be9ab" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_dec60f93" }, { "content": "from typing import Any, List, Tuple", "chunk_type": "import", "name": "Any, List, Tuple", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Tuple_b87f6f42" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_a06d28c6" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_45f95ec7" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_5de38a3f" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
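A minimal sketch for the ObjectCropper chunk above. The crop_dir value and image path are hypothetical, and the constructor kwargs assume the same CFG-driven configuration used by the other solutions in this index.

import cv2
from ultralytics.solutions import ObjectCropper  # assumed import path

cropper = ObjectCropper(model="yolo11n.pt", crop_dir="cropped_objects", conf=0.25, iou=0.5)
frame = cv2.imread("path/to/image.jpg")  # hypothetical input image
results = cropper.process(frame)  # writes crop_1.jpg, crop_2.jpg, ... into crop_dir
print(f"Saved {results.total_crop_objects} crops to cropped_objects/")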
"complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_a0956800" }, { "content": "from ultralytics.utils.checks import check_imshow", "chunk_type": "import", "name": "check_imshow", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_imshow_9861f4b4" }, { "content": "class ParkingPtsSelection:\n \"\"\"\n A class for selecting and managing parking zone points on images using a Tkinter-based UI.\n\n This class provides functionality to upload an image, select points to define parking zones, and save the\n selected points to a JSON file. It uses Tkinter for the graphical user interface.\n\n Attributes:\n tk (module): The Tkinter module for GUI operations.\n filedialog (module): Tkinter's filedialog module for file selection operations.\n messagebox (module): Tkinter's messagebox module for displaying message boxes.\n master (tk.Tk): The main Tkinter window.\n canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes.\n image (PIL.Image.Image): The uploaded image.\n canvas_image (ImageTk.PhotoImage): The image displayed on the canvas.\n rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points.\n current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box.\n imgw (int): Original width of the uploaded image.\n imgh (int): Original height of the uploaded image.\n canvas_max_width (int): Maximum width of the canvas.\n canvas_max_height (int): Maximum height of the canvas.\n\n Methods:\n initialize_properties: Initialize properties for image, canvas, bounding boxes, and dimensions.\n upload_image: Upload and display an image on the canvas, resizing it to fit within specified dimensions.\n on_canvas_click: Handle mouse clicks to add points for bounding boxes on the canvas.\n draw_box: Draw a bounding box on the canvas using the provided coordinates.\n remove_last_bounding_box: Remove the last bounding box from the list and redraw the canvas.\n redraw_canvas: Redraw the canvas with the image and all bounding boxes.\n save_to_json: Save the selected parking zone points to a JSON file with scaled coordinates.\n\n Examples:\n >>> parking_selector = ParkingPtsSelection()\n >>> # Use the GUI to upload an image, select parking zones, and save the data\n \"\"\"\n\n def __init__(self) -> None:\n \"\"\"Initialize the ParkingPtsSelection class, setting up UI and properties for parking zone point selection.\"\"\"\n try: # Check if tkinter is installed\n import tkinter as tk\n from tkinter import filedialog, messagebox\n except ImportError: # Display error with recommendations\n import platform\n\n install_cmd = {\n \"Linux\": \"sudo apt install python3-tk (Debian/Ubuntu) | sudo dnf install python3-tkinter (Fedora) | \"\n \"sudo pacman -S tk (Arch)\",\n \"Windows\": \"reinstall Python and enable the checkbox `tcl/tk and IDLE` on **Optional Features** during installation\",\n \"Darwin\": \"reinstall Python from https://www.python.org/downloads/macos/ or `brew install python-tk`\",\n }.get(platform.system(), \"Unknown OS. Check your Python installation.\")\n\n LOGGER.warning(f\" Tkinter is not configured or supported. 
Potential fix: {install_cmd}\")\n return\n\n if not check_imshow(warn=True):\n return\n\n self.tk, self.filedialog, self.messagebox = tk, filedialog, messagebox\n self.master = self.tk.Tk() # Reference to the main application window\n self.master.title(\"Ultralytics Parking Zones Points Selector\")\n self.master.resizable(False, False)\n\n self.canvas = self.tk.Canvas(self.master, bg=\"white\") # Canvas widget for displaying images\n self.canvas.pack(side=self.tk.BOTTOM)\n\n self.image = None # Variable to store the loaded image\n self.canvas_image = None # Reference to the image displayed on the canvas\n self.canvas_max_width = None # Maximum allowed width for the canvas\n self.canvas_max_height = None # Maximum allowed height for the canvas\n self.rg_data = None # Data for region annotation management\n self.current_box = None # Stores the currently selected bounding box\n self.imgh = None # Height of the current image\n self.imgw = None # Width of the current image\n\n # Button frame with buttons\n button_frame = self.tk.Frame(self.master)\n button_frame.pack(side=self.tk.TOP)\n\n for text, cmd in [\n (\"Upload Image\", self.upload_image),\n (\"Remove Last BBox\", self.remove_last_bounding_box),\n (\"Save\", self.save_to_json),\n ]:\n self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT)\n\n self.initialize_properties()\n self.master.mainloop()\n\n def initialize_properties(self) -> None:\n \"\"\"Initialize properties for image, canvas, bounding boxes, and dimensions.\"\"\"\n self.image = self.canvas_image = None\n self.rg_data, self.current_box = [], []\n self.imgw = self.imgh = 0\n self.canvas_max_width, self.canvas_max_height = 1280, 720\n\n def upload_image(self) -> None:\n \"\"\"Upload and display an image on the canvas, resizing it to fit within specified dimensions.\"\"\"\n from PIL import Image, ImageTk # Scoped import because ImageTk requires tkinter package\n\n file = self.filedialog.askopenfilename(filetypes=[(\"Image Files\", \"*.png *.jpg *.jpeg\")])\n if not file:\n LOGGER.info(\"No image selected.\")\n return\n\n self.image = Image.open(file)\n self.imgw, self.imgh = self.image.size\n aspect_ratio = self.imgw / self.imgh\n canvas_width = (\n min(self.canvas_max_width, self.imgw) if aspect_ratio > 1 else int(self.canvas_max_height * aspect_ratio)\n )\n canvas_height = (\n min(self.canvas_max_height, self.imgh) if aspect_ratio <= 1 else int(canvas_width / aspect_ratio)\n )\n\n self.canvas.config(width=canvas_width, height=canvas_height)\n self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height)))\n self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)\n self.canvas.bind(\"\", self.on_canvas_click)\n\n self.rg_data.clear(), self.current_box.clear()\n\n def on_canvas_click(self, event) -> None:\n \"\"\"Handle mouse clicks to add points for bounding boxes on the canvas.\"\"\"\n self.current_box.append((event.x, event.y))\n self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill=\"red\")\n if len(self.current_box) == 4:\n self.rg_data.append(self.current_box.copy())\n self.draw_box(self.current_box)\n self.current_box.clear()\n\n def draw_box(self, box: List[Tuple[int, int]]) -> None:\n \"\"\"Draw a bounding box on the canvas using the provided coordinates.\"\"\"\n for i in range(4):\n self.canvas.create_line(box[i], box[(i + 1) % 4], fill=\"blue\", width=2)\n\n def remove_last_bounding_box(self) -> None:\n \"\"\"Remove the last bounding box from the list and redraw the 
canvas.\"\"\"\n if not self.rg_data:\n self.messagebox.showwarning(\"Warning\", \"No bounding boxes to remove.\")\n return\n self.rg_data.pop()\n self.redraw_canvas()\n\n def redraw_canvas(self) -> None:\n \"\"\"Redraw the canvas with the image and all bounding boxes.\"\"\"\n self.canvas.delete(\"all\")\n self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)\n for box in self.rg_data:\n self.draw_box(box)\n\n def save_to_json(self) -> None:\n \"\"\"Save the selected parking zone points to a JSON file with scaled coordinates.\"\"\"\n scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()\n data = [{\"points\": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]\n\n from io import StringIO # Function level import, as it's only required to store coordinates\n\n write_buffer = StringIO()\n json.dump(data, write_buffer, indent=4)\n with open(\"bounding_boxes.json\", \"w\", encoding=\"utf-8\") as f:\n f.write(write_buffer.getvalue())\n self.messagebox.showinfo(\"Success\", \"Bounding boxes saved to bounding_boxes.json\")", "chunk_type": "class", "name": "ParkingPtsSelection", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 14, "end_line": 175, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": "A class for selecting and managing parking zone points on images using a Tkinter-based UI.\n\nThis class provides functionality to upload an image, select points to define parking zones, and save the\nselected points to a JSON file. It uses Tkinter for the graphical user interface.\n\nAttributes:\n tk (module): The Tkinter module for GUI operations.\n filedialog (module): Tkinter's filedialog module for file selection operations.\n messagebox (module): Tkinter's messagebox module for displaying message boxes.\n master (tk.Tk): The main Tkinter window.\n canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes.\n image (PIL.Image.Image): The uploaded image.\n canvas_image (ImageTk.PhotoImage): The image displayed on the canvas.\n rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points.\n current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box.\n imgw (int): Original width of the uploaded image.\n imgh (int): Original height of the uploaded image.\n canvas_max_width (int): Maximum width of the canvas.\n canvas_max_height (int): Maximum height of the canvas.\n\nMethods:\n initialize_properties: Initialize properties for image, canvas, bounding boxes, and dimensions.\n upload_image: Upload and display an image on the canvas, resizing it to fit within specified dimensions.\n on_canvas_click: Handle mouse clicks to add points for bounding boxes on the canvas.\n draw_box: Draw a bounding box on the canvas using the provided coordinates.\n remove_last_bounding_box: Remove the last bounding box from the list and redraw the canvas.\n redraw_canvas: Redraw the canvas with the image and all bounding boxes.\n save_to_json: Save the selected parking zone points to a JSON file with scaled coordinates.\n\nExamples:\n >>> parking_selector = ParkingPtsSelection()\n >>> # Use the GUI to upload an image, select parking zones, and save the data", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "typing.Any", "typing.List", "typing.Tuple", "cv2", "numpy", "ultralytics.solutions.solutions.BaseSolution", 
"ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.LOGGER", "ultralytics.utils.checks.check_imshow", "PIL.Image", "PIL.ImageTk", "io.StringIO", "tkinter", "tkinter.filedialog", "tkinter.messagebox", "platform" ], "chunk_id": "class_ParkingPtsSelection_a14389a1" }, { "content": "class ParkingManagement(BaseSolution):\n \"\"\"\n Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.\n\n This class extends BaseSolution to provide functionality for parking lot management, including detection of\n occupied spaces, visualization of parking regions, and display of occupancy statistics.\n\n Attributes:\n json_file (str): Path to the JSON file containing parking region details.\n json (List[Dict]): Loaded JSON data containing parking region information.\n pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).\n arc (Tuple[int, int, int]): RGB color tuple for available region visualization.\n occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization.\n dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects.\n\n Methods:\n process: Process the input image for parking lot management and visualization.\n\n Examples:\n >>> from ultralytics.solutions import ParkingManagement\n >>> parking_manager = ParkingManagement(model=\"yolo11n.pt\", json_file=\"parking_regions.json\")\n >>> print(f\"Occupied spaces: {parking_manager.pr_info['Occupancy']}\")\n >>> print(f\"Available spaces: {parking_manager.pr_info['Available']}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the parking management system with a YOLO model and visualization settings.\"\"\"\n super().__init__(**kwargs)\n\n self.json_file = self.CFG[\"json_file\"] # Load parking regions JSON data\n if self.json_file is None:\n LOGGER.warning(\"json_file argument missing. Parking region details required.\")\n raise ValueError(\"❌ Json file path can not be empty\")\n\n with open(self.json_file) as f:\n self.json = json.load(f)\n\n self.pr_info = {\"Occupancy\": 0, \"Available\": 0} # Dictionary for parking information\n\n self.arc = (0, 0, 255) # Available region color\n self.occ = (0, 255, 0) # Occupied region color\n self.dc = (255, 0, 189) # Centroid color for each box\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Process the input image for parking lot management and visualization.\n\n This function analyzes the input image, extracts tracks, and determines the occupancy status of parking\n regions defined in the JSON file. 
It annotates the image with occupied and available parking spots,\n and updates the parking information.\n\n Args:\n im0 (np.ndarray): The input inference image.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'filled_slots' (number of occupied parking slots),\n 'available_slots' (number of available parking slots), and 'total_tracks' (total number of tracked objects).\n\n Examples:\n >>> parking_manager = ParkingManagement(json_file=\"parking_regions.json\")\n >>> image = cv2.imread(\"parking_lot.jpg\")\n >>> results = parking_manager.process(image)\n \"\"\"\n self.extract_tracks(im0) # Extract tracks from im0\n es, fs = len(self.json), 0 # Empty slots, filled slots\n annotator = SolutionAnnotator(im0, self.line_width) # Initialize annotator\n\n for region in self.json:\n # Convert points to a NumPy array with the correct dtype and reshape properly\n pts_array = np.array(region[\"points\"], dtype=np.int32).reshape((-1, 1, 2))\n rg_occupied = False # Occupied region initialization\n for box, cls in zip(self.boxes, self.clss):\n xc, yc = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n dist = cv2.pointPolygonTest(pts_array, (xc, yc), False)\n if dist >= 0:\n # cv2.circle(im0, (xc, yc), radius=self.line_width * 4, color=self.dc, thickness=-1)\n annotator.display_objects_labels(\n im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10\n )\n rg_occupied = True\n break\n fs, es = (fs + 1, es - 1) if rg_occupied else (fs, es)\n # Plot regions\n cv2.polylines(im0, [pts_array], isClosed=True, color=self.occ if rg_occupied else self.arc, thickness=2)\n\n self.pr_info[\"Occupancy\"], self.pr_info[\"Available\"] = fs, es\n\n annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10)\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return SolutionResults\n return SolutionResults(\n plot_im=plot_im,\n filled_slots=self.pr_info[\"Occupancy\"],\n available_slots=self.pr_info[\"Available\"],\n total_tracks=len(self.track_ids),\n )", "chunk_type": "class", "name": "ParkingManagement", "file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py", "start_line": 178, "end_line": 276, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.\n\nThis class extends BaseSolution to provide functionality for parking lot management, including detection of\noccupied spaces, visualization of parking regions, and display of occupancy statistics.\n\nAttributes:\n json_file (str): Path to the JSON file containing parking region details.\n json (List[Dict]): Loaded JSON data containing parking region information.\n pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).\n arc (Tuple[int, int, int]): RGB color tuple for available region visualization.\n occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization.\n dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects.\n\nMethods:\n process: Process the input image for parking lot management and visualization.\n\nExamples:\n >>> from ultralytics.solutions import ParkingManagement\n >>> parking_manager = ParkingManagement(model=\"yolo11n.pt\", json_file=\"parking_regions.json\")\n >>> print(f\"Occupied spaces: {parking_manager.pr_info['Occupancy']}\")\n >>> print(f\"Available spaces: 
{parking_manager.pr_info['Available']}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "json", "typing.Any", "typing.List", "typing.Tuple", "cv2", "numpy", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.LOGGER", "ultralytics.utils.checks.check_imshow", "PIL.Image", "PIL.ImageTk", "io.StringIO", "tkinter", "tkinter.filedialog", "tkinter.messagebox", "platform", "BaseSolution" ], "chunk_id": "class_ParkingManagement_08332990" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_9aed8648" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_78f87b94" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_46e65bda" }, { "content": "class QueueManager(BaseSolution):\n \"\"\"\n Manages queue counting in real-time video streams based on object tracks.\n\n This class extends BaseSolution to provide functionality for tracking and counting objects within a specified\n region in video frames.\n\n Attributes:\n counts (int): The current count of objects in the queue.\n rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.\n region_length (int): The number of points defining the queue region.\n track_line (List[Tuple[int, int]]): List of track line coordinates.\n track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.\n\n Methods:\n initialize_region: Initialize the queue region.\n process: Process a single frame for queue management.\n extract_tracks: Extract object tracks from the current frame.\n store_tracking_history: Store the tracking history for an object.\n display_output: Display the processed output.\n\n Examples:\n >>> cap = cv2.VideoCapture(\"path/to/video.mp4\")\n >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300])\n >>> while cap.isOpened():\n >>> success, im0 = cap.read()\n >>> if not success:\n >>> break\n >>> results = queue_manager.process(im0)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the QueueManager with parameters for tracking and counting objects in a video stream.\"\"\"\n super().__init__(**kwargs)\n self.initialize_region()\n self.counts = 0 # Queue counts information\n self.rect_color = 
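A minimal sketch for the ParkingManagement chunk above, reusing the bounding_boxes.json produced by ParkingPtsSelection; the model name and video path are placeholders.

import cv2
from ultralytics.solutions import ParkingManagement  # assumed import path

manager = ParkingManagement(model="yolo11n.pt", json_file="bounding_boxes.json")
cap = cv2.VideoCapture("path/to/parking_lot.mp4")  # hypothetical input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = manager.process(frame)  # annotates slots and updates pr_info
    print(f"Occupied: {results.filled_slots}, Available: {results.available_slots}")
cap.release()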
(255, 255, 255) # Rectangle color for visualization\n self.region_length = len(self.region) # Store region length for further usage\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Process queue management for a single frame of video.\n\n Args:\n im0 (np.ndarray): Input image for processing, typically a frame from a video stream.\n\n Returns:\n (SolutionResults): Contains processed image `im0`, 'queue_count' (int, number of objects in the queue) and\n 'total_tracks' (int, total number of tracked objects).\n\n Examples:\n >>> queue_manager = QueueManager()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = queue_manager.process(frame)\n \"\"\"\n self.counts = 0 # Reset counts every frame\n self.extract_tracks(im0) # Extract tracks from the current frame\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2) # Draw region\n\n for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):\n # Draw bounding box and counting region\n annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))\n self.store_tracking_history(track_id, box) # Store track history\n\n # Cache frequently accessed attributes\n track_history = self.track_history.get(track_id, [])\n\n # Store previous position of track and check if the object is inside the counting region\n prev_position = None\n if len(track_history) > 1:\n prev_position = track_history[-2]\n if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):\n self.counts += 1\n\n # Display queue counts\n annotator.queue_counts_display(\n f\"Queue Counts : {str(self.counts)}\",\n points=self.region,\n region_color=self.rect_color,\n txt_color=(104, 31, 17),\n )\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return a SolutionResults object with processed data\n return SolutionResults(plot_im=plot_im, queue_count=self.counts, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "QueueManager", "file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py", "start_line": 9, "end_line": 95, "start_col": 0, "end_col": 106, "parent_name": null, "docstring": "Manages queue counting in real-time video streams based on object tracks.\n\nThis class extends BaseSolution to provide functionality for tracking and counting objects within a specified\nregion in video frames.\n\nAttributes:\n counts (int): The current count of objects in the queue.\n rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.\n region_length (int): The number of points defining the queue region.\n track_line (List[Tuple[int, int]]): List of track line coordinates.\n track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.\n\nMethods:\n initialize_region: Initialize the queue region.\n process: Process a single frame for queue management.\n extract_tracks: Extract object tracks from the current frame.\n store_tracking_history: Store the tracking history for an object.\n display_output: Display the processed output.\n\nExamples:\n >>> cap = cv2.VideoCapture(\"path/to/video.mp4\")\n >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300])\n >>> while cap.isOpened():\n >>> success, im0 = cap.read()\n >>> if not success:\n >>> break\n >>> results = 
queue_manager.process(im0)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_QueueManager_74fd260d" }, { "content": "from typing import Any, Dict, List, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Tuple", "file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple_388808c1" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_17796638" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_0b029d21" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_c1b76362" }, { "content": "class RegionCounter(BaseSolution):\n \"\"\"\n A class for real-time counting of objects within user-defined regions in a video stream.\n\n This class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,\n track objects, and count those objects that pass through each defined region. Useful for applications requiring\n counting in specified areas, such as monitoring zones or segmented sections.\n\n Attributes:\n region_template (dict): Template for creating new counting regions with default attributes including name,\n polygon coordinates, and display colors.\n counting_regions (list): List storing all defined regions, where each entry is based on `region_template`\n and includes specific region settings like name, coordinates, and color.\n region_counts (dict): Dictionary storing the count of objects for each named region.\n\n Methods:\n add_region: Add a new counting region with specified attributes.\n process: Process video frames to count objects in each region.\n initialize_regions: Initialize zones to count the objects in each one. 
Zones could be multiple as well.\n\n Examples:\n Initialize a RegionCounter and add a counting region\n >>> counter = RegionCounter()\n >>> counter.add_region(\"Zone1\", [(100, 100), (200, 100), (200, 200), (100, 200)], (255, 0, 0), (255, 255, 255))\n >>> results = counter.process(frame)\n >>> print(f\"Total tracks: {results.total_tracks}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the RegionCounter for real-time object counting in user-defined regions.\"\"\"\n super().__init__(**kwargs)\n self.region_template = {\n \"name\": \"Default Region\",\n \"polygon\": None,\n \"counts\": 0,\n \"region_color\": (255, 255, 255),\n \"text_color\": (0, 0, 0),\n }\n self.region_counts = {}\n self.counting_regions = []\n self.initialize_regions()\n\n def add_region(\n self,\n name: str,\n polygon_points: List[Tuple],\n region_color: Tuple[int, int, int],\n text_color: Tuple[int, int, int],\n ) -> Dict[str, Any]:\n \"\"\"\n Add a new region to the counting list based on the provided template with specific attributes.\n\n Args:\n name (str): Name assigned to the new region.\n polygon_points (List[Tuple]): List of (x, y) coordinates defining the region's polygon.\n region_color (Tuple[int, int, int]): BGR color for region visualization.\n text_color (Tuple[int, int, int]): BGR color for the text within the region.\n\n Returns:\n (Dict[str, any]): Returns a dictionary including the region information i.e. name, region_color etc.\n \"\"\"\n region = self.region_template.copy()\n region.update(\n {\n \"name\": name,\n \"polygon\": self.Polygon(polygon_points),\n \"region_color\": region_color,\n \"text_color\": text_color,\n }\n )\n self.counting_regions.append(region)\n return region\n\n def initialize_regions(self):\n \"\"\"Initialize regions only once.\"\"\"\n if self.region is None:\n self.initialize_region()\n if not isinstance(self.region, dict): # Ensure self.region is initialized and structured as a dictionary\n self.region = {\"Region#01\": self.region}\n for i, (name, pts) in enumerate(self.region.items()):\n region = self.add_region(name, pts, colors(i, True), (255, 255, 255))\n region[\"prepared_polygon\"] = self.prep(region[\"polygon\"])\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Process the input frame to detect and count objects within each defined region.\n\n Args:\n im0 (np.ndarray): Input image frame where objects and regions are annotated.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked objects),\n and 'region_counts' (dict, counts of objects per region).\n \"\"\"\n self.extract_tracks(im0)\n annotator = SolutionAnnotator(im0, line_width=self.line_width)\n\n for box, cls, track_id, conf in zip(self.boxes, self.clss, self.track_ids, self.confs):\n annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))\n center = self.Point(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))\n for region in self.counting_regions:\n if region[\"prepared_polygon\"].contains(center):\n region[\"counts\"] += 1\n self.region_counts[region[\"name\"]] = region[\"counts\"]\n\n # Display region counts\n for region in self.counting_regions:\n x1, y1, x2, y2 = map(int, region[\"polygon\"].bounds)\n pts = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]\n annotator.draw_region(pts, region[\"region_color\"], self.line_width * 2)\n annotator.text_label(\n [x1, y1, x2, y2],\n label=str(region[\"counts\"]),\n color=region[\"region_color\"],\n 
txt_color=region[\"text_color\"],\n margin=self.line_width * 4,\n )\n region[\"counts\"] = 0 # Reset for next frame\n plot_im = annotator.result()\n self.display_output(plot_im)\n\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), region_counts=self.region_counts)", "chunk_type": "class", "name": "RegionCounter", "file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py", "start_line": 11, "end_line": 132, "start_col": 0, "end_col": 115, "parent_name": null, "docstring": "A class for real-time counting of objects within user-defined regions in a video stream.\n\nThis class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,\ntrack objects, and count those objects that pass through each defined region. Useful for applications requiring\ncounting in specified areas, such as monitoring zones or segmented sections.\n\nAttributes:\n region_template (dict): Template for creating new counting regions with default attributes including name,\n polygon coordinates, and display colors.\n counting_regions (list): List storing all defined regions, where each entry is based on `region_template`\n and includes specific region settings like name, coordinates, and color.\n region_counts (dict): Dictionary storing the count of objects for each named region.\n\nMethods:\n add_region: Add a new counting region with specified attributes.\n process: Process video frames to count objects in each region.\n initialize_regions: Initialize zones to count the objects in each one. Zones could be multiple as well.\n\nExamples:\n Initialize a RegionCounter and add a counting region\n >>> counter = RegionCounter()\n >>> counter.add_region(\"Zone1\", [(100, 100), (200, 100), (200, 200), (100, 200)], (255, 0, 0), (255, 255, 255))\n >>> results = counter.process(frame)\n >>> print(f\"Total tracks: {results.total_tracks}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "numpy", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_RegionCounter_9cb98056" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_d5fd2f85" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_a1445258" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": 
null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_934d92dd" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_0fe10d62" }, { "content": "class SecurityAlarm(BaseSolution):\n \"\"\"\n A class to manage security alarm functionalities for real-time monitoring.\n\n This class extends the BaseSolution class and provides features to monitor objects in a frame, send email\n notifications when specific thresholds are exceeded for total detections, and annotate the output frame for\n visualization.\n\n Attributes:\n email_sent (bool): Flag to track if an email has already been sent for the current event.\n records (int): Threshold for the number of detected objects to trigger an alert.\n server (smtplib.SMTP): SMTP server connection for sending email alerts.\n to_email (str): Recipient's email address for alerts.\n from_email (str): Sender's email address for alerts.\n\n Methods:\n authenticate: Set up email server authentication for sending alerts.\n send_email: Send an email notification with details and an image attachment.\n process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.\n\n Examples:\n >>> security = SecurityAlarm()\n >>> security.authenticate(\"abc@gmail.com\", \"1111222233334444\", \"xyz@gmail.com\")\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = security.process(frame)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the SecurityAlarm class with parameters for real-time object monitoring.\n\n Args:\n **kwargs (Any): Additional keyword arguments passed to the parent class.\n \"\"\"\n super().__init__(**kwargs)\n self.email_sent = False\n self.records = self.CFG[\"records\"]\n self.server = None\n self.to_email = \"\"\n self.from_email = \"\"\n\n def authenticate(self, from_email: str, password: str, to_email: str) -> None:\n \"\"\"\n Authenticate the email server for sending alert notifications.\n\n Args:\n from_email (str): Sender's email address.\n password (str): Password for the sender's email account.\n to_email (str): Recipient's email address.\n\n This method initializes a secure connection with the SMTP server and logs in using the provided credentials.\n\n Examples:\n >>> alarm = SecurityAlarm()\n >>> alarm.authenticate(\"sender@example.com\", \"password123\", \"recipient@example.com\")\n \"\"\"\n import smtplib\n\n self.server = smtplib.SMTP(\"smtp.gmail.com: 587\")\n self.server.starttls()\n self.server.login(from_email, password)\n self.to_email = to_email\n self.from_email = from_email\n\n def send_email(self, im0, records: int = 5) -> None:\n \"\"\"\n Send an email notification with an image attachment indicating the number of objects detected.\n\n Args:\n im0 (np.ndarray): The input image or frame to be attached to the email.\n records (int, optional): The number of detected objects to be included in the email message.\n\n This method encodes the input image, composes the email message with details about the detection, and sends it\n to the specified recipient.\n\n Examples:\n >>> alarm = SecurityAlarm()\n >>> frame = 
cv2.imread(\"path/to/image.jpg\")\n >>> alarm.send_email(frame, records=10)\n \"\"\"\n from email.mime.image import MIMEImage\n from email.mime.multipart import MIMEMultipart\n from email.mime.text import MIMEText\n\n import cv2\n\n img_bytes = cv2.imencode(\".jpg\", im0)[1].tobytes() # Encode the image as JPEG\n\n # Create the email\n message = MIMEMultipart()\n message[\"From\"] = self.from_email\n message[\"To\"] = self.to_email\n message[\"Subject\"] = \"Security Alert\"\n\n # Add the text message body\n message_body = f\"Ultralytics ALERT!!! {records} objects have been detected!!\"\n message.attach(MIMEText(message_body))\n\n # Attach the image\n image_attachment = MIMEImage(img_bytes, name=\"ultralytics.jpg\")\n message.attach(image_attachment)\n\n # Send the email\n try:\n self.server.send_message(message)\n LOGGER.info(\"Email sent successfully!\")\n except Exception as e:\n LOGGER.error(f\"Failed to send email: {e}\")\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Monitor the frame, process object detections, and trigger alerts if thresholds are exceeded.\n\n Args:\n im0 (np.ndarray): The input image or frame to be processed and annotated.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (total number of tracked objects) and\n 'email_sent' (whether an email alert was triggered).\n\n This method processes the input frame, extracts detections, annotates the frame with bounding boxes, and sends\n an email notification if the number of detected objects surpasses the specified threshold and an alert has not\n already been sent.\n\n Examples:\n >>> alarm = SecurityAlarm()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> results = alarm.process(frame)\n \"\"\"\n self.extract_tracks(im0) # Extract tracks\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n # Iterate over bounding boxes and classes index\n for box, cls in zip(self.boxes, self.clss):\n # Draw bounding box\n annotator.box_label(box, label=self.names[cls], color=colors(cls, True))\n\n total_det = len(self.clss)\n if total_det >= self.records and not self.email_sent: # Only send email if not sent before\n self.send_email(im0, total_det)\n self.email_sent = True\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return a SolutionResults\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), email_sent=self.email_sent)", "chunk_type": "class", "name": "SecurityAlarm", "file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py", "start_line": 10, "end_line": 156, "start_col": 0, "end_col": 109, "parent_name": null, "docstring": "A class to manage security alarm functionalities for real-time monitoring.\n\nThis class extends the BaseSolution class and provides features to monitor objects in a frame, send email\nnotifications when specific thresholds are exceeded for total detections, and annotate the output frame for\nvisualization.\n\nAttributes:\n email_sent (bool): Flag to track if an email has already been sent for the current event.\n records (int): Threshold for the number of detected objects to trigger an alert.\n server (smtplib.SMTP): SMTP server connection for sending email alerts.\n to_email (str): Recipient's email address for alerts.\n from_email (str): Sender's email address for alerts.\n\nMethods:\n authenticate: Set up email server authentication for sending alerts.\n send_email: Send an email notification with details and an image 
attachment.\n process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.\n\nExamples:\n >>> security = SecurityAlarm()\n >>> security.authenticate(\"abc@gmail.com\", \"1111222233334444\", \"xyz@gmail.com\")\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = security.process(frame)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.colors", "smtplib", "email.mime.image.MIMEImage", "email.mime.multipart.MIMEMultipart", "email.mime.text.MIMEText", "cv2", "BaseSolution" ], "chunk_id": "class_SecurityAlarm_c056dd1c" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_031fa0f7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_b1440c31" }, { "content": "from typing import Any, List", "chunk_type": "import", "name": "Any, List", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List_8a82c947" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_4dc09686" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_46aafba3" }, { "content": "from ultralytics.data.utils import IMG_FORMATS", "chunk_type": "import", "name": "IMG_FORMATS", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IMG_FORMATS_b9b5e2e6" }, { "content": "from ultralytics.nn.text_model import build_text_model", "chunk_type": "import", "name": "build_text_model", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 54, 
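# --- Usage sketch (added by editor): SecurityAlarm email alerts ---
# A hedged example only; the credentials, video path, and threshold are placeholders.
# `records` mirrors self.CFG["records"] read in SecurityAlarm.__init__ above, and
# authenticate()/process() are the methods defined in the class.
import cv2
from ultralytics.solutions import SecurityAlarm  # assumed package-level export

alarm = SecurityAlarm(model="yolo11n.pt", records=5)  # alert once >= 5 objects are detected
alarm.authenticate("sender@example.com", "app-password", "recipient@example.com")

cap = cv2.VideoCapture("surveillance.mp4")
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = alarm.process(frame)  # draws boxes; emails the frame once, then sets email_sent
cap.release()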
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_build_text_model_9485a61e" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_9e0460cc" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_ed830f39" }, { "content": "from ultralytics.utils.torch_utils import select_device", "chunk_type": "import", "name": "select_device", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_select_device_ecc55285" }, { "content": "class VisualAISearch:\n \"\"\"\n A semantic image search system that leverages OpenCLIP for generating high-quality image and text embeddings and\n FAISS for fast similarity-based retrieval.\n\n This class aligns image and text embeddings in a shared semantic space, enabling users to search large collections\n of images using natural language queries with high accuracy and speed.\n\n Attributes:\n data (str): Directory containing images.\n device (str): Computation device, e.g., 'cpu' or 'cuda'.\n faiss_index (str): Path to the FAISS index file.\n data_path_npy (str): Path to the numpy file storing image paths.\n data_dir (Path): Path object for the data directory.\n model: Loaded CLIP model.\n index: FAISS index for similarity search.\n image_paths (List[str]): List of image file paths.\n\n Methods:\n extract_image_feature: Extract CLIP embedding from an image.\n extract_text_feature: Extract CLIP embedding from text.\n load_or_build_index: Load existing FAISS index or build new one.\n search: Perform semantic search for similar images.\n\n Examples:\n Initialize and search for images\n >>> searcher = VisualAISearch(data=\"path/to/images\", device=\"cuda\")\n >>> results = searcher.search(\"a cat sitting on a chair\", k=10)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the VisualAISearch class with FAISS index and CLIP model.\"\"\"\n check_requirements(\"faiss-cpu\")\n\n self.faiss = __import__(\"faiss\")\n self.faiss_index = \"faiss.index\"\n self.data_path_npy = \"paths.npy\"\n self.data_dir = Path(kwargs.get(\"data\", \"images\"))\n self.device = select_device(kwargs.get(\"device\", \"cpu\"))\n\n if not self.data_dir.exists():\n from ultralytics.utils import ASSETS_URL\n\n LOGGER.warning(f\"{self.data_dir} not found. 
Downloading images.zip from {ASSETS_URL}/images.zip\")\n from ultralytics.utils.downloads import safe_download\n\n safe_download(url=f\"{ASSETS_URL}/images.zip\", unzip=True, retry=3)\n self.data_dir = Path(\"images\")\n\n self.model = build_text_model(\"clip:ViT-B/32\", device=self.device)\n\n self.index = None\n self.image_paths = []\n\n self.load_or_build_index()\n\n def extract_image_feature(self, path: Path) -> np.ndarray:\n \"\"\"Extract CLIP image embedding from the given image path.\"\"\"\n return self.model.encode_image(Image.open(path)).cpu().numpy()\n\n def extract_text_feature(self, text: str) -> np.ndarray:\n \"\"\"Extract CLIP text embedding from the given text query.\"\"\"\n return self.model.encode_text(self.model.tokenize([text])).cpu().numpy()\n\n def load_or_build_index(self) -> None:\n \"\"\"\n Load existing FAISS index or build a new one from image features.\n\n Checks if FAISS index and image paths exist on disk. If found, loads them directly. Otherwise, builds a new\n index by extracting features from all images in the data directory, normalizes the features, and saves both the\n index and image paths for future use.\n \"\"\"\n # Check if the FAISS index and corresponding image paths already exist\n if Path(self.faiss_index).exists() and Path(self.data_path_npy).exists():\n LOGGER.info(\"Loading existing FAISS index...\")\n self.index = self.faiss.read_index(self.faiss_index) # Load the FAISS index from disk\n self.image_paths = np.load(self.data_path_npy) # Load the saved image path list\n return # Exit the function as the index is successfully loaded\n\n # If the index doesn't exist, start building it from scratch\n LOGGER.info(\"Building FAISS index from images...\")\n vectors = [] # List to store feature vectors of images\n\n # Iterate over all image files in the data directory\n for file in self.data_dir.iterdir():\n # Skip files that are not valid image formats\n if file.suffix.lower().lstrip(\".\") not in IMG_FORMATS:\n continue\n try:\n # Extract feature vector for the image and add to the list\n vectors.append(self.extract_image_feature(file))\n self.image_paths.append(file.name) # Store the corresponding image name\n except Exception as e:\n LOGGER.warning(f\"Skipping {file.name}: {e}\")\n\n # If no vectors were successfully created, raise an error\n if not vectors:\n raise RuntimeError(\"No image embeddings could be generated.\")\n\n vectors = np.vstack(vectors).astype(\"float32\") # Stack all vectors into a NumPy array and convert to float32\n self.faiss.normalize_L2(vectors) # Normalize vectors to unit length for cosine similarity\n\n self.index = self.faiss.IndexFlatIP(vectors.shape[1]) # Create a new FAISS index using inner product\n self.index.add(vectors) # Add the normalized vectors to the FAISS index\n self.faiss.write_index(self.index, self.faiss_index) # Save the newly built FAISS index to disk\n np.save(self.data_path_npy, np.array(self.image_paths)) # Save the list of image paths to disk\n\n LOGGER.info(f\"Indexed {len(self.image_paths)} images.\")\n\n def search(self, query: str, k: int = 30, similarity_thresh: float = 0.1) -> List[str]:\n \"\"\"\n Return top-k semantically similar images to the given query.\n\n Args:\n query (str): Natural language text query to search for.\n k (int, optional): Maximum number of results to return.\n similarity_thresh (float, optional): Minimum similarity threshold for filtering results.\n\n Returns:\n (List[str]): List of image filenames ranked by similarity score.\n\n Examples:\n Search for images matching a 
query\n >>> searcher = VisualAISearch(data=\"images\")\n >>> results = searcher.search(\"red car\", k=5, similarity_thresh=0.2)\n \"\"\"\n text_feat = self.extract_text_feature(query).astype(\"float32\")\n self.faiss.normalize_L2(text_feat)\n\n D, index = self.index.search(text_feat, k)\n results = [\n (self.image_paths[i], float(D[0][idx])) for idx, i in enumerate(index[0]) if D[0][idx] >= similarity_thresh\n ]\n results.sort(key=lambda x: x[1], reverse=True)\n\n LOGGER.info(\"\\nRanked Results:\")\n for name, score in results:\n LOGGER.info(f\" - {name} | Similarity: {score:.4f}\")\n\n return [r[0] for r in results]\n\n def __call__(self, query: str) -> List[str]:\n \"\"\"Direct call interface for the search function.\"\"\"\n return self.search(query)", "chunk_type": "class", "name": "VisualAISearch", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 19, "end_line": 162, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "A semantic image search system that leverages OpenCLIP for generating high-quality image and text embeddings and\nFAISS for fast similarity-based retrieval.\n\nThis class aligns image and text embeddings in a shared semantic space, enabling users to search large collections\nof images using natural language queries with high accuracy and speed.\n\nAttributes:\n data (str): Directory containing images.\n device (str): Computation device, e.g., 'cpu' or 'cuda'.\n faiss_index (str): Path to the FAISS index file.\n data_path_npy (str): Path to the numpy file storing image paths.\n data_dir (Path): Path object for the data directory.\n model: Loaded CLIP model.\n index: FAISS index for similarity search.\n image_paths (List[str]): List of image file paths.\n\nMethods:\n extract_image_feature: Extract CLIP embedding from an image.\n extract_text_feature: Extract CLIP embedding from text.\n load_or_build_index: Load existing FAISS index or build new one.\n search: Perform semantic search for similar images.\n\nExamples:\n Initialize and search for images\n >>> searcher = VisualAISearch(data=\"path/to/images\", device=\"cuda\")\n >>> results = searcher.search(\"a cat sitting on a chair\", k=10)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "pathlib.Path", "typing.Any", "typing.List", "numpy", "PIL.Image", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.nn.text_model.build_text_model", "ultralytics.utils.LOGGER", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.torch_utils.select_device", "flask.Flask", "flask.render_template", "flask.request", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.downloads.safe_download" ], "chunk_id": "class_VisualAISearch_06b9dda1" }, { "content": "class SearchApp:\n \"\"\"\n A Flask-based web interface for semantic image search with natural language queries.\n\n This class provides a clean, responsive frontend that enables users to input natural language queries and\n instantly view the most relevant images retrieved from the indexed database.\n\n Attributes:\n render_template: Flask template rendering function.\n request: Flask request object.\n searcher (VisualAISearch): Instance of the VisualAISearch class.\n app (Flask): Flask application instance.\n\n Methods:\n index: Process user queries and display search results.\n run: Start the Flask web application.\n\n Examples:\n Start a search application\n >>> app = SearchApp(data=\"path/to/images\", device=\"cuda\")\n >>> app.run(debug=True)\n \"\"\"\n\n def 
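# --- Usage sketch (added by editor): VisualAISearch semantic image search ---
# Hedged example mirroring the docstring above; the image directory is a placeholder.
# Internally the class L2-normalizes CLIP embeddings and queries a FAISS IndexFlatIP,
# so the inner-product scores behave as cosine similarities.
from ultralytics.solutions import VisualAISearch  # assumed package-level export

searcher = VisualAISearch(data="path/to/images", device="cpu")
hits = searcher.search("a person riding a bicycle", k=5, similarity_thresh=0.1)
print(hits)                   # image filenames ranked by similarity
hits = searcher("a red car")  # __call__ forwards to search() with default arguments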
__init__(self, data: str = \"images\", device: str = None) -> None:\n \"\"\"\n Initialize the SearchApp with VisualAISearch backend.\n\n Args:\n data (str, optional): Path to directory containing images to index and search.\n device (str, optional): Device to run inference on (e.g. 'cpu', 'cuda').\n \"\"\"\n check_requirements(\"flask>=3.0.1\")\n from flask import Flask, render_template, request\n\n self.render_template = render_template\n self.request = request\n self.searcher = VisualAISearch(data=data, device=device)\n self.app = Flask(\n __name__,\n template_folder=\"templates\",\n static_folder=Path(data).resolve(), # Absolute path to serve images\n static_url_path=\"/images\", # URL prefix for images\n )\n self.app.add_url_rule(\"/\", view_func=self.index, methods=[\"GET\", \"POST\"])\n\n def index(self) -> str:\n \"\"\"Process user query and display search results in the web interface.\"\"\"\n results = []\n if self.request.method == \"POST\":\n query = self.request.form.get(\"query\", \"\").strip()\n results = self.searcher(query)\n return self.render_template(\"similarity-search.html\", results=results)\n\n def run(self, debug: bool = False) -> None:\n \"\"\"Start the Flask web application server.\"\"\"\n self.app.run(debug=debug)", "chunk_type": "class", "name": "SearchApp", "file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py", "start_line": 165, "end_line": 220, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "A Flask-based web interface for semantic image search with natural language queries.\n\nThis class provides a clean, responsive frontend that enables users to input natural language queries and\ninstantly view the most relevant images retrieved from the indexed database.\n\nAttributes:\n render_template: Flask template rendering function.\n request: Flask request object.\n searcher (VisualAISearch): Instance of the VisualAISearch class.\n app (Flask): Flask application instance.\n\nMethods:\n index: Process user queries and display search results.\n run: Start the Flask web application.\n\nExamples:\n Start a search application\n >>> app = SearchApp(data=\"path/to/images\", device=\"cuda\")\n >>> app.run(debug=True)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "pathlib.Path", "typing.Any", "typing.List", "numpy", "PIL.Image", "ultralytics.data.utils.IMG_FORMATS", "ultralytics.nn.text_model.build_text_model", "ultralytics.utils.LOGGER", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.torch_utils.select_device", "flask.Flask", "flask.render_template", "flask.request", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.downloads.safe_download" ], "chunk_id": "class_SearchApp_25d60119" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_36e75246" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
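# --- Usage sketch (added by editor): SearchApp web frontend ---
# Hedged example; "images" is a placeholder directory. The root route renders
# templates/similarity-search.html and serves images under /images, as wired in __init__ above.
from ultralytics.solutions import SearchApp  # assumed package-level export

app = SearchApp(data="images", device="cpu")
app.run(debug=True)  # starts the Flask development server with the search form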
"import_defaultdict_8a57e2e2" }, { "content": "from functools import lru_cache", "chunk_type": "import", "name": "lru_cache", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_lru_cache_73213f9b" }, { "content": "from typing import Any, Dict, List, Optional, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple_ed8c10c2" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_f1bc611f" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_13d88f0c" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_6b9996f5" }, { "content": "from ultralytics.solutions.config import SolutionConfig", "chunk_type": "import", "name": "SolutionConfig", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SolutionConfig_51782812" }, { "content": "from ultralytics.utils import ASSETS_URL, LOGGER, ops", "chunk_type": "import", "name": "ASSETS_URL, LOGGER, ops", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ASSETS_URL, LOGGER, ops_c02d76db" }, { "content": "from ultralytics.utils.checks import check_imshow, check_requirements", "chunk_type": "import", "name": "check_imshow, check_requirements", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_imshow, check_requirements_58a3f4ea" }, { "content": "from 
ultralytics.utils.plotting import Annotator", "chunk_type": "import", "name": "Annotator", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Annotator_ee834069" }, { "content": "class BaseSolution:\n \"\"\"\n A base class for managing Ultralytics Solutions.\n\n This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,\n and region initialization. It serves as the foundation for implementing specific computer vision solutions such as\n object counting, pose estimation, and analytics.\n\n Attributes:\n LineString: Class for creating line string geometries from shapely.\n Polygon: Class for creating polygon geometries from shapely.\n Point: Class for creating point geometries from shapely.\n prep: Prepared geometry function from shapely for optimized spatial operations.\n CFG (Dict[str, Any]): Configuration dictionary loaded from YAML file and updated with kwargs.\n LOGGER: Logger instance for solution-specific logging.\n annotator: Annotator instance for drawing on images.\n tracks: YOLO tracking results from the latest inference.\n track_data: Extracted tracking data (boxes or OBB) from tracks.\n boxes (List): Bounding box coordinates from tracking results.\n clss (List[int]): Class indices from tracking results.\n track_ids (List[int]): Track IDs from tracking results.\n confs (List[float]): Confidence scores from tracking results.\n track_line: Current track line for storing tracking history.\n masks: Segmentation masks from tracking results.\n r_s: Region or line geometry object for spatial operations.\n frame_no (int): Current frame number for logging purposes.\n region (List[Tuple[int, int]]): List of coordinate tuples defining region of interest.\n line_width (int): Width of lines used in visualizations.\n model (YOLO): Loaded YOLO model instance.\n names (Dict[int, str]): Dictionary mapping class indices to class names.\n classes (List[int]): List of class indices to track.\n show_conf (bool): Flag to show confidence scores in annotations.\n show_labels (bool): Flag to show class labels in annotations.\n device (str): Device for model inference.\n track_add_args (Dict[str, Any]): Additional arguments for tracking configuration.\n env_check (bool): Flag indicating whether environment supports image display.\n track_history (defaultdict): Dictionary storing tracking history for each object.\n profilers (Tuple): Profiler instances for performance monitoring.\n\n Methods:\n adjust_box_label: Generate formatted label for bounding box.\n extract_tracks: Apply object tracking and extract tracks from input image.\n store_tracking_history: Store object tracking history for given track ID and bounding box.\n initialize_region: Initialize counting region and line segment based on configuration.\n display_output: Display processing results including frames or saved results.\n process: Process method to be implemented by each Solution subclass.\n\n Examples:\n >>> solution = BaseSolution(model=\"yolo11n.pt\", region=[(0, 0), (100, 0), (100, 100), (0, 100)])\n >>> solution.initialize_region()\n >>> image = cv2.imread(\"image.jpg\")\n >>> solution.extract_tracks(image)\n >>> solution.display_output(image)\n \"\"\"\n\n def __init__(self, is_cli: bool = False, **kwargs: Any) -> None:\n \"\"\"\n 
Initialize the BaseSolution class with configuration settings and YOLO model.\n\n Args:\n is_cli (bool): Enable CLI mode if set to True.\n **kwargs (Any): Additional configuration parameters that override defaults.\n \"\"\"\n self.CFG = vars(SolutionConfig().update(**kwargs))\n self.LOGGER = LOGGER # Store logger object to be used in multiple solution classes\n\n check_requirements(\"shapely>=2.0.0\")\n from shapely.geometry import LineString, Point, Polygon\n from shapely.prepared import prep\n\n self.LineString = LineString\n self.Polygon = Polygon\n self.Point = Point\n self.prep = prep\n self.annotator = None # Initialize annotator\n self.tracks = None\n self.track_data = None\n self.boxes = []\n self.clss = []\n self.track_ids = []\n self.track_line = None\n self.masks = None\n self.r_s = None\n self.frame_no = -1 # Only for logging\n\n self.LOGGER.info(f\"Ultralytics Solutions: ✅ {self.CFG}\")\n self.region = self.CFG[\"region\"] # Store region data for other classes usage\n self.line_width = self.CFG[\"line_width\"]\n\n # Load Model and store additional information (classes, show_conf, show_label)\n if self.CFG[\"model\"] is None:\n self.CFG[\"model\"] = \"yolo11n.pt\"\n self.model = YOLO(self.CFG[\"model\"])\n self.names = self.model.names\n self.classes = self.CFG[\"classes\"]\n self.show_conf = self.CFG[\"show_conf\"]\n self.show_labels = self.CFG[\"show_labels\"]\n self.device = self.CFG[\"device\"]\n\n self.track_add_args = { # Tracker additional arguments for advance configuration\n k: self.CFG[k] for k in {\"iou\", \"conf\", \"device\", \"max_det\", \"half\", \"tracker\"}\n } # verbose must be passed to track method; setting it False in YOLO still logs the track information.\n\n if is_cli and self.CFG[\"source\"] is None:\n d_s = \"solutions_ci_demo.mp4\" if \"-pose\" not in self.CFG[\"model\"] else \"solution_ci_pose_demo.mp4\"\n self.LOGGER.warning(f\"source not provided. using default source {ASSETS_URL}/{d_s}\")\n from ultralytics.utils.downloads import safe_download\n\n safe_download(f\"{ASSETS_URL}/{d_s}\") # download source from ultralytics assets\n self.CFG[\"source\"] = d_s # set default source\n\n # Initialize environment and region setup\n self.env_check = check_imshow(warn=True)\n self.track_history = defaultdict(list)\n\n self.profilers = (\n ops.Profile(device=self.device), # track\n ops.Profile(device=self.device), # solution\n )\n\n def adjust_box_label(self, cls: int, conf: float, track_id: Optional[int] = None) -> Optional[str]:\n \"\"\"\n Generate a formatted label for a bounding box.\n\n This method constructs a label string for a bounding box using the class index and confidence score.\n Optionally includes the track ID if provided. 
The label format adapts based on the display settings\n defined in `self.show_conf` and `self.show_labels`.\n\n Args:\n cls (int): The class index of the detected object.\n conf (float): The confidence score of the detection.\n track_id (int, optional): The unique identifier for the tracked object.\n\n Returns:\n (str | None): The formatted label string if `self.show_labels` is True; otherwise, None.\n \"\"\"\n name = (\"\" if track_id is None else f\"{track_id} \") + self.names[cls]\n return (f\"{name} {conf:.2f}\" if self.show_conf else name) if self.show_labels else None\n\n def extract_tracks(self, im0: np.ndarray) -> None:\n \"\"\"\n Apply object tracking and extract tracks from an input image or frame.\n\n Args:\n im0 (np.ndarray): The input image or frame.\n\n Examples:\n >>> solution = BaseSolution()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> solution.extract_tracks(frame)\n \"\"\"\n with self.profilers[0]:\n self.tracks = self.model.track(\n source=im0, persist=True, classes=self.classes, verbose=False, **self.track_add_args\n )[0]\n is_obb = self.tracks.obb is not None\n self.track_data = self.tracks.obb if is_obb else self.tracks.boxes # Extract tracks for OBB or object detection\n\n if self.track_data and self.track_data.is_track:\n self.boxes = (self.track_data.xyxyxyxy if is_obb else self.track_data.xyxy).cpu()\n self.clss = self.track_data.cls.cpu().tolist()\n self.track_ids = self.track_data.id.int().cpu().tolist()\n self.confs = self.track_data.conf.cpu().tolist()\n else:\n self.LOGGER.warning(\"no tracks found!\")\n self.boxes, self.clss, self.track_ids, self.confs = [], [], [], []\n\n def store_tracking_history(self, track_id: int, box) -> None:\n \"\"\"\n Store the tracking history of an object.\n\n This method updates the tracking history for a given object by appending the center point of its\n bounding box to the track line. It maintains a maximum of 30 points in the tracking history.\n\n Args:\n track_id (int): The unique identifier for the tracked object.\n box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].\n\n Examples:\n >>> solution = BaseSolution()\n >>> solution.store_tracking_history(1, [100, 200, 300, 400])\n \"\"\"\n # Store tracking history\n self.track_line = self.track_history[track_id]\n self.track_line.append(tuple(box.mean(dim=0)) if box.numel() > 4 else (box[:4:2].mean(), box[1:4:2].mean()))\n if len(self.track_line) > 30:\n self.track_line.pop(0)\n\n def initialize_region(self) -> None:\n \"\"\"Initialize the counting region and line segment based on configuration settings.\"\"\"\n if self.region is None:\n self.region = [(10, 200), (540, 200), (540, 180), (10, 180)]\n self.r_s = (\n self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)\n ) # region or line\n\n def display_output(self, plot_im: np.ndarray) -> None:\n \"\"\"\n Display the results of the processing, which could involve showing frames, printing counts, or saving results.\n\n This method is responsible for visualizing the output of the object detection and tracking process. 
It displays\n the processed frame with annotations, and allows for user interaction to close the display.\n\n Args:\n plot_im (np.ndarray): The image or frame that has been processed and annotated.\n\n Examples:\n >>> solution = BaseSolution()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> solution.display_output(frame)\n\n Notes:\n - This method will only display output if the 'show' configuration is set to True and the environment\n supports image display.\n - The display can be closed by pressing the 'q' key.\n \"\"\"\n if self.CFG.get(\"show\") and self.env_check:\n cv2.imshow(\"Ultralytics Solutions\", plot_im)\n if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n cv2.destroyAllWindows() # Closes current frame window\n return\n\n def process(self, *args: Any, **kwargs: Any):\n \"\"\"Process method should be implemented by each Solution subclass.\"\"\"\n\n def __call__(self, *args: Any, **kwargs: Any):\n \"\"\"Allow instances to be called like a function with flexible arguments.\"\"\"\n with self.profilers[1]:\n result = self.process(*args, **kwargs) # Call the subclass-specific process method\n track_or_predict = \"predict\" if type(self).__name__ == \"ObjectCropper\" else \"track\"\n track_or_predict_speed = self.profilers[0].dt * 1e3\n solution_speed = (self.profilers[1].dt - self.profilers[0].dt) * 1e3 # solution time = process - track\n result.speed = {track_or_predict: track_or_predict_speed, \"solution\": solution_speed}\n if self.CFG[\"verbose\"]:\n self.frame_no += 1\n LOGGER.info(\n f\"{self.frame_no}: {result.plot_im.shape[0]}x{result.plot_im.shape[1]} {solution_speed:.1f}ms\\n\"\n f\"Speed: {track_or_predict_speed:.1f}ms {track_or_predict}, \"\n f\"{solution_speed:.1f}ms solution per image at shape \"\n f\"(1, {getattr(self.model, 'ch', 3)}, {result.plot_im.shape[0]}, {result.plot_im.shape[1]})\\n\"\n )\n return result", "chunk_type": "class", "name": "BaseSolution", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 18, "end_line": 259, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "A base class for managing Ultralytics Solutions.\n\nThis class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,\nand region initialization. 
It serves as the foundation for implementing specific computer vision solutions such as\nobject counting, pose estimation, and analytics.\n\nAttributes:\n LineString: Class for creating line string geometries from shapely.\n Polygon: Class for creating polygon geometries from shapely.\n Point: Class for creating point geometries from shapely.\n prep: Prepared geometry function from shapely for optimized spatial operations.\n CFG (Dict[str, Any]): Configuration dictionary loaded from YAML file and updated with kwargs.\n LOGGER: Logger instance for solution-specific logging.\n annotator: Annotator instance for drawing on images.\n tracks: YOLO tracking results from the latest inference.\n track_data: Extracted tracking data (boxes or OBB) from tracks.\n boxes (List): Bounding box coordinates from tracking results.\n clss (List[int]): Class indices from tracking results.\n track_ids (List[int]): Track IDs from tracking results.\n confs (List[float]): Confidence scores from tracking results.\n track_line: Current track line for storing tracking history.\n masks: Segmentation masks from tracking results.\n r_s: Region or line geometry object for spatial operations.\n frame_no (int): Current frame number for logging purposes.\n region (List[Tuple[int, int]]): List of coordinate tuples defining region of interest.\n line_width (int): Width of lines used in visualizations.\n model (YOLO): Loaded YOLO model instance.\n names (Dict[int, str]): Dictionary mapping class indices to class names.\n classes (List[int]): List of class indices to track.\n show_conf (bool): Flag to show confidence scores in annotations.\n show_labels (bool): Flag to show class labels in annotations.\n device (str): Device for model inference.\n track_add_args (Dict[str, Any]): Additional arguments for tracking configuration.\n env_check (bool): Flag indicating whether environment supports image display.\n track_history (defaultdict): Dictionary storing tracking history for each object.\n profilers (Tuple): Profiler instances for performance monitoring.\n\nMethods:\n adjust_box_label: Generate formatted label for bounding box.\n extract_tracks: Apply object tracking and extract tracks from input image.\n store_tracking_history: Store object tracking history for given track ID and bounding box.\n initialize_region: Initialize counting region and line segment based on configuration.\n display_output: Display processing results including frames or saved results.\n process: Process method to be implemented by each Solution subclass.\n\nExamples:\n >>> solution = BaseSolution(model=\"yolo11n.pt\", region=[(0, 0), (100, 0), (100, 100), (0, 100)])\n >>> solution.initialize_region()\n >>> image = cv2.imread(\"image.jpg\")\n >>> solution.extract_tracks(image)\n >>> solution.display_output(image)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "collections.defaultdict", "functools.lru_cache", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "ultralytics.YOLO", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_imshow", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.plotting.Annotator", "shapely.geometry.LineString", "shapely.geometry.Point", "shapely.geometry.Polygon", "shapely.prepared.prep", "ultralytics.utils.downloads.safe_download" ], "chunk_id": "class_BaseSolution_d8c7373e" }, { "content": 
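# --- Usage sketch (added by editor): a minimal BaseSolution subclass ---
# Hedged illustration only; the class name and drawing logic are invented, but the
# overridden process() hook, extract_tracks(), SolutionAnnotator, and SolutionResults
# follow the base-class contract shown above.
import cv2
from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults

class TrackPrinter(BaseSolution):  # hypothetical subclass for demonstration
    def process(self, im0):
        self.extract_tracks(im0)  # fills self.boxes / self.clss / self.track_ids / self.confs
        annotator = SolutionAnnotator(im0, line_width=self.line_width)
        for box, cls in zip(self.boxes, self.clss):
            annotator.box_label(box, label=self.names[cls])
        plot_im = annotator.result()
        self.display_output(plot_im)  # no-op unless show=True and the environment supports imshow
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))

solution = TrackPrinter(model="yolo11n.pt", show=False)
frame = cv2.imread("image.jpg")  # placeholder path
results = solution(frame)        # __call__ times tracking vs. solution work and sets results.speed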
"class SolutionAnnotator(Annotator):\n \"\"\"\n A specialized annotator class for visualizing and analyzing computer vision tasks.\n\n This class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking\n trails, and visual annotations for Ultralytics Solutions. It offers comprehensive visualization capabilities for\n various computer vision applications including object detection, tracking, pose estimation, and analytics.\n\n Attributes:\n im (np.ndarray): The image being annotated.\n line_width (int): Thickness of lines used in annotations.\n font_size (int): Size of the font used for text annotations.\n font (str): Path to the font file used for text rendering.\n pil (bool): Whether to use PIL for text rendering.\n example (str): An example attribute for demonstration purposes.\n\n Methods:\n draw_region: Draw a region using specified points, colors, and thickness.\n queue_counts_display: Display queue counts in the specified region.\n display_analytics: Display overall statistics for parking lot management.\n estimate_pose_angle: Calculate the angle between three points in an object pose.\n draw_specific_kpts: Draw specific keypoints on the image.\n plot_workout_information: Draw a labeled text box on the image.\n plot_angle_and_count_and_stage: Visualize angle, step count, and stage for workout monitoring.\n plot_distance_and_line: Display the distance between centroids and connect them with a line.\n display_objects_labels: Annotate bounding boxes with object class labels.\n sweep_annotator: Visualize a vertical sweep line and optional label.\n visioneye: Map and connect object centroids to a visual \"eye\" point.\n circle_label: Draw a circular label within a bounding box.\n text_label: Draw a rectangular label within a bounding box.\n\n Examples:\n >>> annotator = SolutionAnnotator(image)\n >>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5)\n >>> annotator.display_analytics(\n ... image, text={\"Available Spots\": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10\n ... 
)\n \"\"\"\n\n def __init__(\n self,\n im: np.ndarray,\n line_width: Optional[int] = None,\n font_size: Optional[int] = None,\n font: str = \"Arial.ttf\",\n pil: bool = False,\n example: str = \"abc\",\n ):\n \"\"\"\n Initialize the SolutionAnnotator class with an image for annotation.\n\n Args:\n im (np.ndarray): The image to be annotated.\n line_width (int, optional): Line thickness for drawing on the image.\n font_size (int, optional): Font size for text annotations.\n font (str): Path to the font file.\n pil (bool): Indicates whether to use PIL for rendering text.\n example (str): An example parameter for demonstration purposes.\n \"\"\"\n super().__init__(im, line_width, font_size, font, pil, example)\n\n def draw_region(\n self,\n reg_pts: Optional[List[Tuple[int, int]]] = None,\n color: Tuple[int, int, int] = (0, 255, 0),\n thickness: int = 5,\n ):\n \"\"\"\n Draw a region or line on the image.\n\n Args:\n reg_pts (List[Tuple[int, int]], optional): Region points (for line 2 points, for region 4+ points).\n color (Tuple[int, int, int]): RGB color value for the region.\n thickness (int): Line thickness for drawing the region.\n \"\"\"\n cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)\n\n # Draw small circles at the corner points\n for point in reg_pts:\n cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle\n\n def queue_counts_display(\n self,\n label: str,\n points: Optional[List[Tuple[int, int]]] = None,\n region_color: Tuple[int, int, int] = (255, 255, 255),\n txt_color: Tuple[int, int, int] = (0, 0, 0),\n ):\n \"\"\"\n Display queue counts on an image centered at the points with customizable font size and colors.\n\n Args:\n label (str): Queue counts label.\n points (List[Tuple[int, int]], optional): Region points for center point calculation to display text.\n region_color (Tuple[int, int, int]): RGB queue region color.\n txt_color (Tuple[int, int, int]): RGB text display color.\n \"\"\"\n x_values = [point[0] for point in points]\n y_values = [point[1] for point in points]\n center_x = sum(x_values) // len(points)\n center_y = sum(y_values) // len(points)\n\n text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]\n text_width = text_size[0]\n text_height = text_size[1]\n\n rect_width = text_width + 20\n rect_height = text_height + 20\n rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)\n rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)\n cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)\n\n text_x = center_x - text_width // 2\n text_y = center_y + text_height // 2\n\n # Draw text\n cv2.putText(\n self.im,\n label,\n (text_x, text_y),\n 0,\n fontScale=self.sf,\n color=txt_color,\n thickness=self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def display_analytics(\n self,\n im0: np.ndarray,\n text: Dict[str, Any],\n txt_color: Tuple[int, int, int],\n bg_color: Tuple[int, int, int],\n margin: int,\n ):\n \"\"\"\n Display the overall statistics for parking lots, object counter etc.\n\n Args:\n im0 (np.ndarray): Inference image.\n text (Dict[str, Any]): Labels dictionary.\n txt_color (Tuple[int, int, int]): Display color for text foreground.\n bg_color (Tuple[int, int, int]): Display color for text background.\n margin (int): Gap between text and rectangle for better display.\n \"\"\"\n horizontal_gap = int(im0.shape[1] * 0.02)\n vertical_gap = int(im0.shape[0] * 0.01)\n text_y_offset = 0\n for 
label, value in text.items():\n txt = f\"{label}: {value}\"\n text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]\n if text_size[0] < 5 or text_size[1] < 5:\n text_size = (5, 5)\n text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap\n text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap\n rect_x1 = text_x - margin * 2\n rect_y1 = text_y - text_size[1] - margin * 2\n rect_x2 = text_x + text_size[0] + margin * 2\n rect_y2 = text_y + margin * 2\n cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)\n cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)\n text_y_offset = rect_y2\n\n @staticmethod\n @lru_cache(maxsize=256)\n def estimate_pose_angle(a: List[float], b: List[float], c: List[float]) -> float:\n \"\"\"\n Calculate the angle between three points for workout monitoring.\n\n Args:\n a (List[float]): The coordinates of the first point.\n b (List[float]): The coordinates of the second point (vertex).\n c (List[float]): The coordinates of the third point.\n\n Returns:\n (float): The angle in degrees between the three points.\n \"\"\"\n radians = math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])\n angle = abs(radians * 180.0 / math.pi)\n return angle if angle <= 180.0 else (360 - angle)\n\n def draw_specific_kpts(\n self,\n keypoints: List[List[float]],\n indices: Optional[List[int]] = None,\n radius: int = 2,\n conf_thresh: float = 0.25,\n ) -> np.ndarray:\n \"\"\"\n Draw specific keypoints for gym steps counting.\n\n Args:\n keypoints (List[List[float]]): Keypoints data to be plotted, each in format [x, y, confidence].\n indices (List[int], optional): Keypoint indices to be plotted.\n radius (int): Keypoint radius.\n conf_thresh (float): Confidence threshold for keypoints.\n\n Returns:\n (np.ndarray): Image with drawn keypoints.\n\n Notes:\n Keypoint format: [x, y] or [x, y, confidence].\n Modifies self.im in-place.\n \"\"\"\n indices = indices or [2, 5, 7]\n points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thresh]\n\n # Draw lines between consecutive points\n for start, end in zip(points[:-1], points[1:]):\n cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)\n\n # Draw circles for keypoints\n for pt in points:\n cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)\n\n return self.im\n\n def plot_workout_information(\n self,\n display_text: str,\n position: Tuple[int, int],\n color: Tuple[int, int, int] = (104, 31, 17),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ) -> int:\n \"\"\"\n Draw workout text with a background on the image.\n\n Args:\n display_text (str): The text to be displayed.\n position (Tuple[int, int]): Coordinates (x, y) on the image where the text will be placed.\n color (Tuple[int, int, int]): Text background color.\n txt_color (Tuple[int, int, int]): Text foreground color.\n\n Returns:\n (int): The height of the text.\n \"\"\"\n (text_width, text_height), _ = cv2.getTextSize(display_text, 0, fontScale=self.sf, thickness=self.tf)\n\n # Draw background rectangle\n cv2.rectangle(\n self.im,\n (position[0], position[1] - text_height - 5),\n (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),\n color,\n -1,\n )\n # Draw text\n cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)\n\n return text_height\n\n def plot_angle_and_count_and_stage(\n self,\n angle_text: str,\n count_text: str,\n 
stage_text: str,\n center_kpt: List[int],\n color: Tuple[int, int, int] = (104, 31, 17),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ):\n \"\"\"\n Plot the pose angle, count value, and step stage for workout monitoring.\n\n Args:\n angle_text (str): Angle value for workout monitoring.\n count_text (str): Counts value for workout monitoring.\n stage_text (str): Stage decision for workout monitoring.\n center_kpt (List[int]): Centroid pose index for workout monitoring.\n color (Tuple[int, int, int]): Text background color.\n txt_color (Tuple[int, int, int]): Text foreground color.\n \"\"\"\n # Format text\n angle_text, count_text, stage_text = f\" {angle_text:.2f}\", f\"Steps : {count_text}\", f\" {stage_text}\"\n\n # Draw angle, count and stage text\n angle_height = self.plot_workout_information(\n angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color\n )\n count_height = self.plot_workout_information(\n count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color\n )\n self.plot_workout_information(\n stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color\n )\n\n def plot_distance_and_line(\n self,\n pixels_distance: float,\n centroids: List[Tuple[int, int]],\n line_color: Tuple[int, int, int] = (104, 31, 17),\n centroid_color: Tuple[int, int, int] = (255, 0, 255),\n ):\n \"\"\"\n Plot the distance and line between two centroids on the frame.\n\n Args:\n pixels_distance (float): Pixels distance between two bbox centroids.\n centroids (List[Tuple[int, int]]): Bounding box centroids data.\n line_color (Tuple[int, int, int]): Distance line color.\n centroid_color (Tuple[int, int, int]): Bounding box centroid color.\n \"\"\"\n # Get the text size\n text = f\"Pixels Distance: {pixels_distance:.2f}\"\n (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)\n\n # Define corners with 10-pixel margin and draw rectangle\n cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)\n\n # Calculate the position for the text with a 10-pixel margin and draw text\n text_position = (25, 25 + text_height_m + 10)\n cv2.putText(\n self.im,\n text,\n text_position,\n 0,\n self.sf,\n (255, 255, 255),\n self.tf,\n cv2.LINE_AA,\n )\n\n cv2.line(self.im, centroids[0], centroids[1], line_color, 3)\n cv2.circle(self.im, centroids[0], 6, centroid_color, -1)\n cv2.circle(self.im, centroids[1], 6, centroid_color, -1)\n\n def display_objects_labels(\n self,\n im0: np.ndarray,\n text: str,\n txt_color: Tuple[int, int, int],\n bg_color: Tuple[int, int, int],\n x_center: float,\n y_center: float,\n margin: int,\n ):\n \"\"\"\n Display the bounding boxes labels in parking management app.\n\n Args:\n im0 (np.ndarray): Inference image.\n text (str): Object/class name.\n txt_color (Tuple[int, int, int]): Display color for text foreground.\n bg_color (Tuple[int, int, int]): Display color for text background.\n x_center (float): The x position center point for bounding box.\n y_center (float): The y position center point for bounding box.\n margin (int): The gap between text and rectangle for better display.\n \"\"\"\n text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]\n text_x = x_center - text_size[0] // 2\n text_y = y_center + text_size[1] // 2\n\n rect_x1 = text_x - margin\n rect_y1 = text_y - text_size[1] - margin\n rect_x2 = text_x + text_size[0] + margin\n rect_y2 = text_y + margin\n cv2.rectangle(\n im0,\n (int(rect_x1), 
int(rect_y1)),\n (int(rect_x2), int(rect_y2)),\n tuple(map(int, bg_color)), # Ensure color values are int\n -1,\n )\n\n cv2.putText(\n im0,\n text,\n (int(text_x), int(text_y)),\n 0,\n self.sf,\n tuple(map(int, txt_color)), # Ensure color values are int\n self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def sweep_annotator(\n self,\n line_x: int = 0,\n line_y: int = 0,\n label: Optional[str] = None,\n color: Tuple[int, int, int] = (221, 0, 186),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ):\n \"\"\"\n Draw a sweep annotation line and an optional label.\n\n Args:\n line_x (int): The x-coordinate of the sweep line.\n line_y (int): The y-coordinate limit of the sweep line.\n label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.\n color (Tuple[int, int, int]): RGB color for the line and label background.\n txt_color (Tuple[int, int, int]): RGB color for the label text.\n \"\"\"\n # Draw the sweep line\n cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)\n\n # Draw label, if provided\n if label:\n (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)\n cv2.rectangle(\n self.im,\n (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),\n (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),\n color,\n -1,\n )\n cv2.putText(\n self.im,\n label,\n (line_x - text_width // 2, line_y // 2 + text_height // 2),\n cv2.FONT_HERSHEY_SIMPLEX,\n self.sf,\n txt_color,\n self.tf,\n )\n\n def visioneye(\n self,\n box: List[float],\n center_point: Tuple[int, int],\n color: Tuple[int, int, int] = (235, 219, 11),\n pin_color: Tuple[int, int, int] = (255, 0, 255),\n ):\n \"\"\"\n Perform pinpoint human-vision eye mapping and plotting.\n\n Args:\n box (List[float]): Bounding box coordinates in format [x1, y1, x2, y2].\n center_point (Tuple[int, int]): Center point for vision eye view.\n color (Tuple[int, int, int]): Object centroid and line color.\n pin_color (Tuple[int, int, int]): Visioneye point color.\n \"\"\"\n center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)\n cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)\n cv2.line(self.im, center_point, center_bbox, color, self.tf)\n\n def circle_label(\n self,\n box: Tuple[float, float, float, float],\n label: str = \"\",\n color: Tuple[int, int, int] = (128, 128, 128),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n margin: int = 2,\n ):\n \"\"\"\n Draw a label with a background circle centered within a given bounding box.\n\n Args:\n box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).\n label (str): The text label to be displayed.\n color (Tuple[int, int, int]): The background color of the circle (B, G, R).\n txt_color (Tuple[int, int, int]): The color of the text (R, G, B).\n margin (int): The margin between the text and the circle border.\n \"\"\"\n if len(label) > 3:\n LOGGER.warning(f\"Length of label is {len(label)}, only first 3 letters will be used for circle annotation.\")\n label = label[:3]\n\n # Calculate the center of the box\n x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n # Get the text size\n text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]\n # Calculate the required radius to fit the text with the margin\n required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin\n # Draw the 
circle with the required radius\n cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)\n # Calculate the position for the text\n text_x = x_center - text_size[0] // 2\n text_y = y_center + text_size[1] // 2\n # Draw the text\n cv2.putText(\n self.im,\n str(label),\n (text_x, text_y),\n cv2.FONT_HERSHEY_SIMPLEX,\n self.sf - 0.15,\n self.get_txt_color(color, txt_color),\n self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def text_label(\n self,\n box: Tuple[float, float, float, float],\n label: str = \"\",\n color: Tuple[int, int, int] = (128, 128, 128),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n margin: int = 5,\n ):\n \"\"\"\n Draw a label with a background rectangle centered within a given bounding box.\n\n Args:\n box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).\n label (str): The text label to be displayed.\n color (Tuple[int, int, int]): The background color of the rectangle (B, G, R).\n txt_color (Tuple[int, int, int]): The color of the text (R, G, B).\n margin (int): The margin between the text and the rectangle border.\n \"\"\"\n # Calculate the center of the bounding box\n x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n # Get the size of the text\n text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]\n # Calculate the top-left corner of the text (to center it)\n text_x = x_center - text_size[0] // 2\n text_y = y_center + text_size[1] // 2\n # Calculate the coordinates of the background rectangle\n rect_x1 = text_x - margin\n rect_y1 = text_y - text_size[1] - margin\n rect_x2 = text_x + text_size[0] + margin\n rect_y2 = text_y + margin\n # Draw the background rectangle\n cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)\n # Draw the text on top of the rectangle\n cv2.putText(\n self.im,\n label,\n (text_x, text_y),\n cv2.FONT_HERSHEY_SIMPLEX,\n self.sf - 0.1,\n self.get_txt_color(color, txt_color),\n self.tf,\n lineType=cv2.LINE_AA,\n )", "chunk_type": "class", "name": "SolutionAnnotator", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 262, "end_line": 785, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A specialized annotator class for visualizing and analyzing computer vision tasks.\n\nThis class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking\ntrails, and visual annotations for Ultralytics Solutions. 
It offers comprehensive visualization capabilities for\nvarious computer vision applications including object detection, tracking, pose estimation, and analytics.\n\nAttributes:\n im (np.ndarray): The image being annotated.\n line_width (int): Thickness of lines used in annotations.\n font_size (int): Size of the font used for text annotations.\n font (str): Path to the font file used for text rendering.\n pil (bool): Whether to use PIL for text rendering.\n example (str): An example attribute for demonstration purposes.\n\nMethods:\n draw_region: Draw a region using specified points, colors, and thickness.\n queue_counts_display: Display queue counts in the specified region.\n display_analytics: Display overall statistics for parking lot management.\n estimate_pose_angle: Calculate the angle between three points in an object pose.\n draw_specific_kpts: Draw specific keypoints on the image.\n plot_workout_information: Draw a labeled text box on the image.\n plot_angle_and_count_and_stage: Visualize angle, step count, and stage for workout monitoring.\n plot_distance_and_line: Display the distance between centroids and connect them with a line.\n display_objects_labels: Annotate bounding boxes with object class labels.\n sweep_annotator: Visualize a vertical sweep line and optional label.\n visioneye: Map and connect object centroids to a visual \"eye\" point.\n circle_label: Draw a circular label within a bounding box.\n text_label: Draw a rectangular label within a bounding box.\n\nExamples:\n >>> annotator = SolutionAnnotator(image)\n >>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5)\n >>> annotator.display_analytics(\n ... image, text={\"Available Spots\": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10\n ... )", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "collections.defaultdict", "functools.lru_cache", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "ultralytics.YOLO", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_imshow", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.plotting.Annotator", "shapely.geometry.LineString", "shapely.geometry.Point", "shapely.geometry.Polygon", "shapely.prepared.prep", "ultralytics.utils.downloads.safe_download", "Annotator" ], "chunk_id": "class_SolutionAnnotator_faab5a5d" }, { "content": "class SolutionResults:\n \"\"\"\n A class to encapsulate the results of Ultralytics Solutions.\n\n This class is designed to store and manage various outputs generated by the solution pipeline, including counts,\n angles, workout stages, and other analytics data. 
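The SolutionAnnotator label helpers shown above (display_objects_labels, sweep_annotator, visioneye, circle_label, text_label) all draw in place on the image passed to the constructor, reusing the inherited font scale (sf) and thickness (tf). A minimal usage sketch on a synthetic frame, assuming an installed ultralytics package; the frame contents, box coordinates and labels are made up for illustration:

```python
import cv2
import numpy as np

from ultralytics.solutions.solutions import SolutionAnnotator

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # synthetic BGR frame standing in for a video frame
annotator = SolutionAnnotator(frame, line_width=2)

box = (200, 150, 440, 330)  # hypothetical (x1, y1, x2, y2) bounding box
annotator.circle_label(box, label="42", color=(128, 128, 128))    # circular badge (max 3 characters)
annotator.text_label(box, label="person", color=(104, 31, 17))    # rectangular label centered in the box
annotator.sweep_annotator(line_x=320, line_y=480, label="sweep")  # vertical sweep line with label
annotator.visioneye(list(box), center_point=(320, 470))           # centroid-to-"eye" mapping

cv2.imwrite("annotated.jpg", annotator.result())  # annotated image returned as np.ndarray
```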
It provides a structured way to access and manipulate results\n from different computer vision solutions such as object counting, pose estimation, and tracking analytics.\n\n Attributes:\n plot_im (np.ndarray): Processed image with counts, blurred, or other effects from solutions.\n in_count (int): The total number of \"in\" counts in a video stream.\n out_count (int): The total number of \"out\" counts in a video stream.\n classwise_count (Dict[str, int]): A dictionary containing counts of objects categorized by class.\n queue_count (int): The count of objects in a queue or waiting area.\n workout_count (int): The count of workout repetitions.\n workout_angle (float): The angle calculated during a workout exercise.\n workout_stage (str): The current stage of the workout.\n pixels_distance (float): The calculated distance in pixels between two points or objects.\n available_slots (int): The number of available slots in a monitored area.\n filled_slots (int): The number of filled slots in a monitored area.\n email_sent (bool): A flag indicating whether an email notification was sent.\n total_tracks (int): The total number of tracked objects.\n region_counts (Dict[str, int]): The count of objects within a specific region.\n speed_dict (Dict[str, float]): A dictionary containing speed information for tracked objects.\n total_crop_objects (int): Total number of cropped objects using ObjectCropper class.\n speed (Dict[str, float]): Performance timing information for tracking and solution processing.\n \"\"\"\n\n def __init__(self, **kwargs):\n \"\"\"\n Initialize a SolutionResults object with default or user-specified values.\n\n Args:\n **kwargs (Any): Optional arguments to override default attribute values.\n \"\"\"\n self.plot_im = None\n self.in_count = 0\n self.out_count = 0\n self.classwise_count = {}\n self.queue_count = 0\n self.workout_count = 0\n self.workout_angle = 0.0\n self.workout_stage = None\n self.pixels_distance = 0.0\n self.available_slots = 0\n self.filled_slots = 0\n self.email_sent = False\n self.total_tracks = 0\n self.region_counts = {}\n self.speed_dict = {} # for speed estimation\n self.total_crop_objects = 0\n self.speed = {}\n\n # Override with user-defined values\n self.__dict__.update(kwargs)\n\n def __str__(self) -> str:\n \"\"\"\n Return a formatted string representation of the SolutionResults object.\n\n Returns:\n (str): A string representation listing non-null attributes.\n \"\"\"\n attrs = {\n k: v\n for k, v in self.__dict__.items()\n if k != \"plot_im\" and v not in [None, {}, 0, 0.0, False] # Exclude `plot_im` explicitly\n }\n return \", \".join(f\"{k}={v}\" for k, v in attrs.items())", "chunk_type": "class", "name": "SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\solutions.py", "start_line": 788, "end_line": 856, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "A class to encapsulate the results of Ultralytics Solutions.\n\nThis class is designed to store and manage various outputs generated by the solution pipeline, including counts,\nangles, workout stages, and other analytics data. 
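SolutionResults is a plain attribute container: defaults are set in __init__, any keyword argument overrides them, and __str__ prints only the fields that differ from their empty defaults (plot_im is always excluded). A small sketch with invented values:

```python
from ultralytics.solutions.solutions import SolutionResults

# Hypothetical values standing in for one processed frame of an object-counting solution
results = SolutionResults(in_count=12, out_count=7, classwise_count={"car": 15, "bus": 4}, total_tracks=19)

print(results)               # -> in_count=12, out_count=7, classwise_count={'car': 15, 'bus': 4}, total_tracks=19
print(results.total_tracks)  # attributes remain directly accessible, e.g. 19
```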
It provides a structured way to access and manipulate results\nfrom different computer vision solutions such as object counting, pose estimation, and tracking analytics.\n\nAttributes:\n plot_im (np.ndarray): Processed image with counts, blurred, or other effects from solutions.\n in_count (int): The total number of \"in\" counts in a video stream.\n out_count (int): The total number of \"out\" counts in a video stream.\n classwise_count (Dict[str, int]): A dictionary containing counts of objects categorized by class.\n queue_count (int): The count of objects in a queue or waiting area.\n workout_count (int): The count of workout repetitions.\n workout_angle (float): The angle calculated during a workout exercise.\n workout_stage (str): The current stage of the workout.\n pixels_distance (float): The calculated distance in pixels between two points or objects.\n available_slots (int): The number of available slots in a monitored area.\n filled_slots (int): The number of filled slots in a monitored area.\n email_sent (bool): A flag indicating whether an email notification was sent.\n total_tracks (int): The total number of tracked objects.\n region_counts (Dict[str, int]): The count of objects within a specific region.\n speed_dict (Dict[str, float]): A dictionary containing speed information for tracked objects.\n total_crop_objects (int): Total number of cropped objects using ObjectCropper class.\n speed (Dict[str, float]): Performance timing information for tracking and solution processing.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "collections.defaultdict", "functools.lru_cache", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "cv2", "numpy", "ultralytics.YOLO", "ultralytics.solutions.config.SolutionConfig", "ultralytics.utils.ASSETS_URL", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_imshow", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.plotting.Annotator", "shapely.geometry.LineString", "shapely.geometry.Point", "shapely.geometry.Polygon", "shapely.prepared.prep", "ultralytics.utils.downloads.safe_download" ], "chunk_id": "class_SolutionResults_66279a46" }, { "content": "from collections import deque", "chunk_type": "import", "name": "deque", "file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deque_a826132f" }, { "content": "from math import sqrt", "chunk_type": "import", "name": "sqrt", "file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_sqrt_d4a1e012" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_b82552cb" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, 
SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_a1387d74" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_7dec7666" }, { "content": "class SpeedEstimator(BaseSolution):\n \"\"\"\n A class to estimate the speed of objects in a real-time video stream based on their tracks.\n\n This class extends the BaseSolution class and provides functionality for estimating object speeds using\n tracking data in video streams. Speed is calculated based on pixel displacement over time and converted\n to real-world units using a configurable meters-per-pixel scale factor.\n\n Attributes:\n fps (float): Video frame rate for time calculations.\n frame_count (int): Global frame counter for tracking temporal information.\n trk_frame_ids (dict): Maps track IDs to their first frame index.\n spd (dict): Final speed per object in km/h once locked.\n trk_hist (dict): Maps track IDs to deque of position history.\n locked_ids (set): Track IDs whose speed has been finalized.\n max_hist (int): Required frame history before computing speed.\n meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.\n max_speed (int): Maximum allowed object speed; values above this will be capped.\n\n Methods:\n process: Process input frames to estimate object speeds based on tracking data.\n store_tracking_history: Store the tracking history for an object.\n extract_tracks: Extract tracks from the current frame.\n display_output: Display the output with annotations.\n\n Examples:\n Initialize speed estimator and process a frame\n >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = estimator.process(frame)\n >>> cv2.imshow(\"Speed Estimation\", results.plot_im)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the SpeedEstimator object with speed estimation parameters and data structures.\n\n Args:\n **kwargs (Any): Additional keyword arguments passed to the parent class.\n \"\"\"\n super().__init__(**kwargs)\n\n self.fps = self.CFG[\"fps\"] # Video frame rate for time calculations\n self.frame_count = 0 # Global frame counter\n self.trk_frame_ids = {} # Track ID → first frame index\n self.spd = {} # Final speed per object (km/h), once locked\n self.trk_hist = {} # Track ID → deque of (time, position)\n self.locked_ids = set() # Track IDs whose speed has been finalized\n self.max_hist = self.CFG[\"max_hist\"] # Required frame history before computing speed\n self.meter_per_pixel = self.CFG[\"meter_per_pixel\"] # Scene scale, depends on camera details\n self.max_speed = self.CFG[\"max_speed\"] # Maximum speed adjustment\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Process an input frame to estimate object 
speeds based on tracking data.\n\n Args:\n im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).\n\n Examples:\n Process a frame for speed estimation\n >>> estimator = SpeedEstimator()\n >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> results = estimator.process(image)\n \"\"\"\n self.frame_count += 1\n self.extract_tracks(im0)\n annotator = SolutionAnnotator(im0, line_width=self.line_width)\n\n for box, track_id, _, _ in zip(self.boxes, self.track_ids, self.clss, self.confs):\n self.store_tracking_history(track_id, box)\n\n if track_id not in self.trk_hist: # Initialize history if new track found\n self.trk_hist[track_id] = deque(maxlen=self.max_hist)\n self.trk_frame_ids[track_id] = self.frame_count\n\n if track_id not in self.locked_ids: # Update history until speed is locked\n trk_hist = self.trk_hist[track_id]\n trk_hist.append(self.track_line[-1])\n\n # Compute and lock speed once enough history is collected\n if len(trk_hist) == self.max_hist:\n p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track\n dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds\n if dt > 0:\n dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement\n pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance\n meters = pixel_distance * self.meter_per_pixel # Convert to meters\n self.spd[track_id] = int(\n min((meters / dt) * 3.6, self.max_speed)\n ) # Convert to km/h and store final speed\n self.locked_ids.add(track_id) # Prevent further updates\n self.trk_hist.pop(track_id, None) # Free memory\n self.trk_frame_ids.pop(track_id, None) # Remove frame start reference\n\n if track_id in self.spd:\n speed_label = f\"{self.spd[track_id]} km/h\"\n annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return results with processed image and tracking summary\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "SpeedEstimator", "file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py", "start_line": 11, "end_line": 117, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "A class to estimate the speed of objects in a real-time video stream based on their tracks.\n\nThis class extends the BaseSolution class and provides functionality for estimating object speeds using\ntracking data in video streams. 
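The speed lock in process() reduces to simple arithmetic: once max_hist positions are collected, the first and last points give a pixel displacement, the elapsed frames divided by fps give the time window, and meter_per_pixel plus the 3.6 factor convert metres per second to km/h, capped at max_speed. A standalone worked example of that formula with made-up numbers, rather than a call into the class:

```python
from math import sqrt

# Assumed configuration values (CFG["fps"], CFG["meter_per_pixel"], CFG["max_speed"])
fps = 30.0                # frames per second of the video
meter_per_pixel = 0.04    # real-world metres represented by one pixel
max_speed = 120           # cap in km/h

# First and last tracked positions over a hypothetical 15-frame history
p0, p1 = (100.0, 220.0), (190.0, 235.0)
frames_elapsed = 15

dt = frames_elapsed / fps                                           # 0.5 s
pixel_distance = sqrt((p1[0] - p0[0]) ** 2 + (p1[1] - p0[1]) ** 2)  # ~91.2 px
meters = pixel_distance * meter_per_pixel                           # ~3.65 m
speed_kmh = int(min((meters / dt) * 3.6, max_speed))                # -> 26 km/h
print(speed_kmh)
```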
Speed is calculated based on pixel displacement over time and converted\nto real-world units using a configurable meters-per-pixel scale factor.\n\nAttributes:\n fps (float): Video frame rate for time calculations.\n frame_count (int): Global frame counter for tracking temporal information.\n trk_frame_ids (dict): Maps track IDs to their first frame index.\n spd (dict): Final speed per object in km/h once locked.\n trk_hist (dict): Maps track IDs to deque of position history.\n locked_ids (set): Track IDs whose speed has been finalized.\n max_hist (int): Required frame history before computing speed.\n meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.\n max_speed (int): Maximum allowed object speed; values above this will be capped.\n\nMethods:\n process: Process input frames to estimate object speeds based on tracking data.\n store_tracking_history: Store the tracking history for an object.\n extract_tracks: Extract tracks from the current frame.\n display_output: Display the output with annotations.\n\nExamples:\n Initialize speed estimator and process a frame\n >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = estimator.process(frame)\n >>> cv2.imshow(\"Speed Estimation\", results.plot_im)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.deque", "math.sqrt", "typing.Any", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_SpeedEstimator_c026ecda" }, { "content": "import io", "chunk_type": "import", "name": "io", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_io_58d2a448" }, { "content": "from typing import Any, List", "chunk_type": "import", "name": "Any, List", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List_035f5a5b" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_4567a341" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_b8239f7b" }, { "content": "from ultralytics import YOLO", "chunk_type": "import", "name": "YOLO", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 
28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO_2bd34118" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_47f3cb60" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_e1938d5b" }, { "content": "from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS", "chunk_type": "import", "name": "GITHUB_ASSETS_STEMS", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_GITHUB_ASSETS_STEMS_899663f0" }, { "content": "class Inference:\n \"\"\"\n A class to perform object detection, image classification, image segmentation and pose estimation inference.\n\n This class provides functionalities for loading models, configuring settings, uploading video files, and performing\n real-time inference using Streamlit and Ultralytics YOLO models.\n\n Attributes:\n st (module): Streamlit module for UI creation.\n temp_dict (dict): Temporary dictionary to store the model path and other configuration.\n model_path (str): Path to the loaded model.\n model (YOLO): The YOLO model instance.\n source (str): Selected video source (webcam or video file).\n enable_trk (bool): Enable tracking option.\n conf (float): Confidence threshold for detection.\n iou (float): IoU threshold for non-maximum suppression.\n org_frame (Any): Container for the original frame to be displayed.\n ann_frame (Any): Container for the annotated frame to be displayed.\n vid_file_name (str | int): Name of the uploaded video file or webcam index.\n selected_ind (List[int]): List of selected class indices for detection.\n\n Methods:\n web_ui: Set up the Streamlit web interface with custom HTML elements.\n sidebar: Configure the Streamlit sidebar for model and inference settings.\n source_upload: Handle video file uploads through the Streamlit interface.\n configure: Configure the model and load selected classes for inference.\n inference: Perform real-time object detection inference.\n\n Examples:\n Create an Inference instance with a custom model\n >>> inf = Inference(model=\"path/to/model.pt\")\n >>> inf.inference()\n\n Create an Inference instance with default settings\n >>> inf = Inference()\n >>> inf.inference()\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the Inference class, checking Streamlit requirements and setting up the model path.\n\n Args:\n **kwargs (Any): Additional keyword arguments for model configuration.\n \"\"\"\n check_requirements(\"streamlit>=1.29.0\") # 
scope imports for faster ultralytics package load speeds\n import streamlit as st\n\n self.st = st # Reference to the Streamlit module\n self.source = None # Video source selection (webcam or video file)\n self.enable_trk = False # Flag to toggle object tracking\n self.conf = 0.25 # Confidence threshold for detection\n self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression\n self.org_frame = None # Container for the original frame display\n self.ann_frame = None # Container for the annotated frame display\n self.vid_file_name = None # Video file name or webcam index\n self.selected_ind: List[int] = [] # List of selected class indices for detection\n self.model = None # YOLO model instance\n\n self.temp_dict = {\"model\": None, **kwargs}\n self.model_path = None # Model file path\n if self.temp_dict[\"model\"] is not None:\n self.model_path = self.temp_dict[\"model\"]\n\n LOGGER.info(f\"Ultralytics Solutions: ✅ {self.temp_dict}\")\n\n def web_ui(self) -> None:\n \"\"\"Set up the Streamlit web interface with custom HTML elements.\"\"\"\n menu_style_cfg = \"\"\"\"\"\" # Hide main menu style\n\n # Main title of streamlit application\n main_title_cfg = \"\"\"

Ultralytics YOLO Streamlit Application\n \"\"\"\n\n # Subtitle of streamlit application\n sub_title_cfg = \"\"\"\n Experience real-time object detection on your webcam with the power \n of Ultralytics YOLO! 🚀

\"\"\"\n\n # Set html page configuration and append custom HTML\n self.st.set_page_config(page_title=\"Ultralytics Streamlit App\", layout=\"wide\")\n self.st.markdown(menu_style_cfg, unsafe_allow_html=True)\n self.st.markdown(main_title_cfg, unsafe_allow_html=True)\n self.st.markdown(sub_title_cfg, unsafe_allow_html=True)\n\n def sidebar(self) -> None:\n \"\"\"Configure the Streamlit sidebar for model and inference settings.\"\"\"\n with self.st.sidebar: # Add Ultralytics LOGO\n logo = \"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg\"\n self.st.image(logo, width=250)\n\n self.st.sidebar.title(\"User Configuration\") # Add elements to vertical setting menu\n self.source = self.st.sidebar.selectbox(\n \"Video\",\n (\"webcam\", \"video\"),\n ) # Add source selection dropdown\n self.enable_trk = self.st.sidebar.radio(\"Enable Tracking\", (\"Yes\", \"No\")) == \"Yes\" # Enable object tracking\n self.conf = float(\n self.st.sidebar.slider(\"Confidence Threshold\", 0.0, 1.0, self.conf, 0.01)\n ) # Slider for confidence\n self.iou = float(self.st.sidebar.slider(\"IoU Threshold\", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold\n\n col1, col2 = self.st.columns(2) # Create two columns for displaying frames\n self.org_frame = col1.empty() # Container for original frame\n self.ann_frame = col2.empty() # Container for annotated frame\n\n def source_upload(self) -> None:\n \"\"\"Handle video file uploads through the Streamlit interface.\"\"\"\n self.vid_file_name = \"\"\n if self.source == \"video\":\n vid_file = self.st.sidebar.file_uploader(\"Upload Video File\", type=[\"mp4\", \"mov\", \"avi\", \"mkv\"])\n if vid_file is not None:\n g = io.BytesIO(vid_file.read()) # BytesIO Object\n with open(\"ultralytics.mp4\", \"wb\") as out: # Open temporary file as bytes\n out.write(g.read()) # Read bytes into file\n self.vid_file_name = \"ultralytics.mp4\"\n elif self.source == \"webcam\":\n self.vid_file_name = 0 # Use webcam index 0\n\n def configure(self) -> None:\n \"\"\"Configure the model and load selected classes for inference.\"\"\"\n # Add dropdown menu for model selection\n M_ORD, T_ORD = [\"yolo11n\", \"yolo11s\", \"yolo11m\", \"yolo11l\", \"yolo11x\"], [\"\", \"-seg\", \"-pose\", \"-obb\", \"-cls\"]\n available_models = sorted(\n [\n x.replace(\"yolo\", \"YOLO\")\n for x in GITHUB_ASSETS_STEMS\n if any(x.startswith(b) for b in M_ORD) and \"grayscale\" not in x\n ],\n key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or \"\")),\n )\n if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later\n available_models.insert(0, self.model_path.split(\".pt\", 1)[0])\n selected_model = self.st.sidebar.selectbox(\"Model\", available_models)\n\n with self.st.spinner(\"Model is downloading...\"):\n self.model = YOLO(f\"{selected_model.lower()}.pt\") # Load the YOLO model\n class_names = list(self.model.names.values()) # Convert dictionary to list of class names\n self.st.success(\"Model loaded successfully!\")\n\n # Multiselect box with class names and get indices of selected classes\n selected_classes = self.st.sidebar.multiselect(\"Classes\", class_names, default=class_names[:3])\n self.selected_ind = [class_names.index(option) for option in selected_classes]\n\n if not isinstance(self.selected_ind, list): # Ensure selected_options is a list\n self.selected_ind = list(self.selected_ind)\n\n def inference(self) -> None:\n \"\"\"Perform real-time object detection inference on video or 
webcam feed.\"\"\"\n self.web_ui() # Initialize the web interface\n self.sidebar() # Create the sidebar\n self.source_upload() # Upload the video source\n self.configure() # Configure the app\n\n if self.st.sidebar.button(\"Start\"):\n stop_button = self.st.button(\"Stop\") # Button to stop the inference\n cap = cv2.VideoCapture(self.vid_file_name) # Capture the video\n if not cap.isOpened():\n self.st.error(\"Could not open webcam or video source.\")\n return\n\n while cap.isOpened():\n success, frame = cap.read()\n if not success:\n self.st.warning(\"Failed to read frame from webcam. Please verify the webcam is connected properly.\")\n break\n\n # Process frame with model\n if self.enable_trk:\n results = self.model.track(\n frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True\n )\n else:\n results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)\n\n annotated_frame = results[0].plot() # Add annotations on frame\n\n if stop_button:\n cap.release() # Release the capture\n self.st.stop() # Stop streamlit app\n\n self.org_frame.image(frame, channels=\"BGR\") # Display original frame\n self.ann_frame.image(annotated_frame, channels=\"BGR\") # Display processed frame\n\n cap.release() # Release the capture\n cv2.destroyAllWindows() # Destroy all OpenCV windows", "chunk_type": "class", "name": "Inference", "file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py", "start_line": 17, "end_line": 202, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "A class to perform object detection, image classification, image segmentation and pose estimation inference.\n\nThis class provides functionalities for loading models, configuring settings, uploading video files, and performing\nreal-time inference using Streamlit and Ultralytics YOLO models.\n\nAttributes:\n st (module): Streamlit module for UI creation.\n temp_dict (dict): Temporary dictionary to store the model path and other configuration.\n model_path (str): Path to the loaded model.\n model (YOLO): The YOLO model instance.\n source (str): Selected video source (webcam or video file).\n enable_trk (bool): Enable tracking option.\n conf (float): Confidence threshold for detection.\n iou (float): IoU threshold for non-maximum suppression.\n org_frame (Any): Container for the original frame to be displayed.\n ann_frame (Any): Container for the annotated frame to be displayed.\n vid_file_name (str | int): Name of the uploaded video file or webcam index.\n selected_ind (List[int]): List of selected class indices for detection.\n\nMethods:\n web_ui: Set up the Streamlit web interface with custom HTML elements.\n sidebar: Configure the Streamlit sidebar for model and inference settings.\n source_upload: Handle video file uploads through the Streamlit interface.\n configure: Configure the model and load selected classes for inference.\n inference: Perform real-time object detection inference.\n\nExamples:\n Create an Inference instance with a custom model\n >>> inf = Inference(model=\"path/to/model.pt\")\n >>> inf.inference()\n\n Create an Inference instance with default settings\n >>> inf = Inference()\n >>> inf.inference()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "io", "typing.Any", "typing.List", "cv2", "torch", "ultralytics.YOLO", "ultralytics.utils.LOGGER", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.downloads.GITHUB_ASSETS_STEMS", "sys", "streamlit" ], "chunk_id": 
"class_Inference_9c8db937" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_1cb3a850" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_ce09cae2" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_87cb7acd" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_8b131528" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_ecd6be4a" }, { "content": "class TrackZone(BaseSolution):\n \"\"\"\n A class to manage region-based object tracking in a video stream.\n\n This class extends the BaseSolution class and provides functionality for tracking objects within a specific region\n defined by a polygonal area. 
Objects outside the region are excluded from tracking.\n\n Attributes:\n region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.\n line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.\n names (List[str]): List of class names that the model can detect.\n boxes (List[np.ndarray]): Bounding boxes of tracked objects.\n track_ids (List[int]): Unique identifiers for each tracked object.\n clss (List[int]): Class indices of tracked objects.\n\n Methods:\n process: Process each frame of the video, applying region-based tracking.\n extract_tracks: Extract tracking information from the input frame.\n display_output: Display the processed output.\n\n Examples:\n >>> tracker = TrackZone()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = tracker.process(frame)\n >>> cv2.imshow(\"Tracked Frame\", results.plot_im)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the TrackZone class for tracking objects within a defined region in video streams.\n\n Args:\n **kwargs (Any): Additional keyword arguments passed to the parent class.\n \"\"\"\n super().__init__(**kwargs)\n default_region = [(75, 75), (565, 75), (565, 285), (75, 285)]\n self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32))\n self.mask = None\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Process the input frame to track objects within a defined region.\n\n This method initializes the annotator, creates a mask for the specified region, extracts tracks\n only from the masked area, and updates tracking information. Objects outside the region are ignored.\n\n Args:\n im0 (np.ndarray): The input image or frame to be processed.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the\n total number of tracked objects within the defined region.\n\n Examples:\n >>> tracker = TrackZone()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> results = tracker.process(frame)\n \"\"\"\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n if self.mask is None: # Create a mask for the region\n self.mask = np.zeros_like(im0[:, :, 0])\n cv2.fillPoly(self.mask, [self.region], 255)\n masked_frame = cv2.bitwise_and(im0, im0, mask=self.mask)\n self.extract_tracks(masked_frame)\n\n # Draw the region boundary\n cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2)\n\n # Iterate over boxes, track ids, classes indexes list and draw bounding boxes\n for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):\n annotator.box_label(\n box, label=self.adjust_box_label(cls, conf, track_id=track_id), color=colors(track_id, True)\n )\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return a SolutionResults\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "TrackZone", "file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py", "start_line": 12, "end_line": 91, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "A class to manage region-based object tracking in a video stream.\n\nThis class extends the BaseSolution class and provides functionality for tracking objects within a specific region\ndefined by a polygonal area. 
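The key step in TrackZone.process is that tracking runs on a masked copy of the frame, so detections can only come from inside the convex hull of the region. A standalone sketch of that masking step on a synthetic frame, using the default region from __init__:

```python
import cv2
import numpy as np

frame = np.full((480, 640, 3), 255, dtype=np.uint8)      # synthetic white frame
default_region = [(75, 75), (565, 75), (565, 285), (75, 285)]
region = cv2.convexHull(np.array(default_region, dtype=np.int32))

mask = np.zeros(frame.shape[:2], dtype=np.uint8)         # equivalent to np.zeros_like(im0[:, :, 0])
cv2.fillPoly(mask, [region], 255)                        # 255 inside the zone, 0 outside
masked_frame = cv2.bitwise_and(frame, frame, mask=mask)  # only the zone is visible to the tracker

# Region boundary drawn on the original frame, as in process()
cv2.polylines(frame, [region], isClosed=True, color=(255, 255, 255), thickness=4)
```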
Objects outside the region are excluded from tracking.\n\nAttributes:\n region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.\n line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.\n names (List[str]): List of class names that the model can detect.\n boxes (List[np.ndarray]): Bounding boxes of tracked objects.\n track_ids (List[int]): Unique identifiers for each tracked object.\n clss (List[int]): Class indices of tracked objects.\n\nMethods:\n process: Process each frame of the video, applying region-based tracking.\n extract_tracks: Extract tracking information from the input frame.\n display_output: Display the processed output.\n\nExamples:\n >>> tracker = TrackZone()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = tracker.process(frame)\n >>> cv2.imshow(\"Tracked Frame\", results.plot_im)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "cv2", "numpy", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_TrackZone_d7ec1fe6" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_c71ce5ab" }, { "content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults", "chunk_type": "import", "name": "BaseSolution, SolutionAnnotator, SolutionResults", "file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_61009772" }, { "content": "from ultralytics.utils.plotting import colors", "chunk_type": "import", "name": "colors", "file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colors_6e3f8bf7" }, { "content": "class VisionEye(BaseSolution):\n \"\"\"\n A class to manage object detection and vision mapping in images or video streams.\n\n This class extends the BaseSolution class and provides functionality for detecting objects,\n mapping vision points, and annotating results with bounding boxes and labels.\n\n Attributes:\n vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.\n\n Methods:\n process: Process the input image to detect objects, annotate them, and apply vision mapping.\n\n Examples:\n >>> vision_eye = VisionEye()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = vision_eye.process(frame)\n >>> print(f\"Total detected instances: {results.total_tracks}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the VisionEye class for detecting objects and applying vision mapping.\n\n Args:\n **kwargs 
(Any): Keyword arguments passed to the parent class and for configuring vision_point.\n \"\"\"\n super().__init__(**kwargs)\n # Set the vision point where the system will view objects and draw tracks\n self.vision_point = self.CFG[\"vision_point\"]\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Perform object detection, vision mapping, and annotation on the input image.\n\n Args:\n im0 (np.ndarray): The input image for detection and annotation.\n\n Returns:\n (SolutionResults): Object containing the annotated image and tracking statistics.\n - plot_im: Annotated output image with bounding boxes and vision mapping\n - total_tracks: Number of tracked objects in the frame\n\n Examples:\n >>> vision_eye = VisionEye()\n >>> frame = cv2.imread(\"image.jpg\")\n >>> results = vision_eye.process(frame)\n >>> print(f\"Detected {results.total_tracks} objects\")\n \"\"\"\n self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)\n annotator = SolutionAnnotator(im0, self.line_width)\n\n for cls, t_id, box, conf in zip(self.clss, self.track_ids, self.boxes, self.confs):\n # Annotate the image with bounding boxes, labels, and vision mapping\n annotator.box_label(box, label=self.adjust_box_label(cls, conf, t_id), color=colors(int(t_id), True))\n annotator.visioneye(box, self.vision_point)\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display the annotated output using the base class function\n\n # Return a SolutionResults object with the annotated image and tracking statistics\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))", "chunk_type": "class", "name": "VisionEye", "file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py", "start_line": 9, "end_line": 70, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "A class to manage object detection and vision mapping in images or video streams.\n\nThis class extends the BaseSolution class and provides functionality for detecting objects,\nmapping vision points, and annotating results with bounding boxes and labels.\n\nAttributes:\n vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.\n\nMethods:\n process: Process the input image to detect objects, annotate them, and apply vision mapping.\n\nExamples:\n >>> vision_eye = VisionEye()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = vision_eye.process(frame)\n >>> print(f\"Total detected instances: {results.total_tracks}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "ultralytics.solutions.solutions.BaseSolution", "ultralytics.solutions.solutions.SolutionAnnotator", "ultralytics.solutions.solutions.SolutionResults", "ultralytics.utils.plotting.colors", "BaseSolution" ], "chunk_id": "class_VisionEye_8a71f3f3" }, { "content": "from .ai_gym import AIGym", "chunk_type": "import", "name": "AIGym", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AIGym_c56e4460" }, { "content": "from .analytics import Analytics", "chunk_type": "import", "name": "Analytics", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, 
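A usage sketch for VisionEye, assuming a downloadable detection model and that vision_point can be overridden through the solution keyword arguments (it is read from CFG["vision_point"] in __init__); the frame path and point are hypothetical:

```python
import cv2

from ultralytics import solutions

visioneye = solutions.VisionEye(model="yolo11n.pt", vision_point=(50, 50))

frame = cv2.imread("frame.jpg")  # hypothetical input frame
results = visioneye.process(frame)

print(f"Total detected instances: {results.total_tracks}")
cv2.imwrite("visioneye_output.jpg", results.plot_im)
```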
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Analytics_d00bee54" }, { "content": "from .distance_calculation import DistanceCalculation", "chunk_type": "import", "name": "DistanceCalculation", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DistanceCalculation_91f34e07" }, { "content": "from .heatmap import Heatmap", "chunk_type": "import", "name": "Heatmap", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Heatmap_08c5d6f2" }, { "content": "from .instance_segmentation import InstanceSegmentation", "chunk_type": "import", "name": "InstanceSegmentation", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_InstanceSegmentation_4f5ecc69" }, { "content": "from .object_blurrer import ObjectBlurrer", "chunk_type": "import", "name": "ObjectBlurrer", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ObjectBlurrer_e71acde0" }, { "content": "from .object_counter import ObjectCounter", "chunk_type": "import", "name": "ObjectCounter", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ObjectCounter_a05e4b23" }, { "content": "from .object_cropper import ObjectCropper", "chunk_type": "import", "name": "ObjectCropper", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ObjectCropper_0deaa30e" }, { "content": "from .parking_management import ParkingManagement, ParkingPtsSelection", "chunk_type": "import", "name": "ParkingManagement, ParkingPtsSelection", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ParkingManagement, ParkingPtsSelection_29d30cd2" }, { "content": "from .queue_management import QueueManager", "chunk_type": "import", "name": "QueueManager", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": 
null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_QueueManager_faed8198" }, { "content": "from .region_counter import RegionCounter", "chunk_type": "import", "name": "RegionCounter", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RegionCounter_0bb865cd" }, { "content": "from .security_alarm import SecurityAlarm", "chunk_type": "import", "name": "SecurityAlarm", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SecurityAlarm_e3cdaa74" }, { "content": "from .similarity_search import SearchApp, VisualAISearch", "chunk_type": "import", "name": "SearchApp, VisualAISearch", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SearchApp, VisualAISearch_0adcf032" }, { "content": "from .speed_estimation import SpeedEstimator", "chunk_type": "import", "name": "SpeedEstimator", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SpeedEstimator_5abd8870" }, { "content": "from .streamlit_inference import Inference", "chunk_type": "import", "name": "Inference", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Inference_c613e25e" }, { "content": "from .trackzone import TrackZone", "chunk_type": "import", "name": "TrackZone", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TrackZone_586fe219" }, { "content": "from .vision_eye import VisionEye", "chunk_type": "import", "name": "VisionEye", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_VisionEye_37d08a7d" }, { "content": "__all__ = (\n \"ObjectCounter\",\n \"ObjectCropper\",\n \"ObjectBlurrer\",\n \"AIGym\",\n \"RegionCounter\",\n \"SecurityAlarm\",\n \"Heatmap\",\n \"InstanceSegmentation\",\n \"VisionEye\",\n \"SpeedEstimator\",\n \"DistanceCalculation\",\n \"QueueManager\",\n \"ParkingManagement\",\n \"ParkingPtsSelection\",\n \"Analytics\",\n \"Inference\",\n \"TrackZone\",\n 
\"SearchApp\",\n \"VisualAISearch\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\solutions\\__init__.py", "start_line": 21, "end_line": 41, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___5c311cc6" }, { "content": "from collections import OrderedDict", "chunk_type": "import", "name": "OrderedDict", "file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OrderedDict_aa0536c1" }, { "content": "from typing import Any", "chunk_type": "import", "name": "Any", "file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any_879757ea" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_8c9fb52b" }, { "content": "class TrackState:\n \"\"\"\n Enumeration class representing the possible states of an object being tracked.\n\n Attributes:\n New (int): State when the object is newly detected.\n Tracked (int): State when the object is successfully tracked in subsequent frames.\n Lost (int): State when the object is no longer tracked.\n Removed (int): State when the object is removed from tracking.\n\n Examples:\n >>> state = TrackState.New\n >>> if state == TrackState.New:\n >>> print(\"Object is newly detected.\")\n \"\"\"\n\n New = 0\n Tracked = 1\n Lost = 2\n Removed = 3", "chunk_type": "class", "name": "TrackState", "file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py", "start_line": 10, "end_line": 29, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Enumeration class representing the possible states of an object being tracked.\n\nAttributes:\n New (int): State when the object is newly detected.\n Tracked (int): State when the object is successfully tracked in subsequent frames.\n Lost (int): State when the object is no longer tracked.\n Removed (int): State when the object is removed from tracking.\n\nExamples:\n >>> state = TrackState.New\n >>> if state == TrackState.New:\n >>> print(\"Object is newly detected.\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.OrderedDict", "typing.Any", "numpy" ], "chunk_id": "class_TrackState_516af351" }, { "content": "class BaseTrack:\n \"\"\"\n Base class for object tracking, providing foundational attributes and methods.\n\n Attributes:\n _count (int): Class-level counter for unique track IDs.\n track_id (int): Unique identifier for the track.\n is_activated (bool): Flag indicating whether the track is currently active.\n state (TrackState): Current state of the track.\n history (OrderedDict): Ordered history of the track's states.\n features (list): List of features 
extracted from the object for tracking.\n curr_feature (Any): The current feature of the object being tracked.\n score (float): The confidence score of the tracking.\n start_frame (int): The frame number where tracking started.\n frame_id (int): The most recent frame ID processed by the track.\n time_since_update (int): Frames passed since the last update.\n location (tuple): The location of the object in the context of multi-camera tracking.\n\n Methods:\n end_frame: Returns the ID of the last frame where the object was tracked.\n next_id: Increments and returns the next global track ID.\n activate: Abstract method to activate the track.\n predict: Abstract method to predict the next state of the track.\n update: Abstract method to update the track with new data.\n mark_lost: Marks the track as lost.\n mark_removed: Marks the track as removed.\n reset_id: Resets the global track ID counter.\n\n Examples:\n Initialize a new track and mark it as lost:\n >>> track = BaseTrack()\n >>> track.mark_lost()\n >>> print(track.state) # Output: 2 (TrackState.Lost)\n \"\"\"\n\n _count = 0\n\n def __init__(self):\n \"\"\"Initialize a new track with a unique ID and foundational tracking attributes.\"\"\"\n self.track_id = 0\n self.is_activated = False\n self.state = TrackState.New\n self.history = OrderedDict()\n self.features = []\n self.curr_feature = None\n self.score = 0\n self.start_frame = 0\n self.frame_id = 0\n self.time_since_update = 0\n self.location = (np.inf, np.inf)\n\n @property\n def end_frame(self) -> int:\n \"\"\"Return the ID of the most recent frame where the object was tracked.\"\"\"\n return self.frame_id\n\n @staticmethod\n def next_id() -> int:\n \"\"\"Increment and return the next unique global track ID for object tracking.\"\"\"\n BaseTrack._count += 1\n return BaseTrack._count\n\n def activate(self, *args: Any) -> None:\n \"\"\"Activate the track with provided arguments, initializing necessary attributes for tracking.\"\"\"\n raise NotImplementedError\n\n def predict(self) -> None:\n \"\"\"Predict the next state of the track based on the current state and tracking model.\"\"\"\n raise NotImplementedError\n\n def update(self, *args: Any, **kwargs: Any) -> None:\n \"\"\"Update the track with new observations and data, modifying its state and attributes accordingly.\"\"\"\n raise NotImplementedError\n\n def mark_lost(self) -> None:\n \"\"\"Mark the track as lost by updating its state to TrackState.Lost.\"\"\"\n self.state = TrackState.Lost\n\n def mark_removed(self) -> None:\n \"\"\"Mark the track as removed by setting its state to TrackState.Removed.\"\"\"\n self.state = TrackState.Removed\n\n @staticmethod\n def reset_id() -> None:\n \"\"\"Reset the global track ID counter to its initial value.\"\"\"\n BaseTrack._count = 0", "chunk_type": "class", "name": "BaseTrack", "file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py", "start_line": 32, "end_line": 117, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Base class for object tracking, providing foundational attributes and methods.\n\nAttributes:\n _count (int): Class-level counter for unique track IDs.\n track_id (int): Unique identifier for the track.\n is_activated (bool): Flag indicating whether the track is currently active.\n state (TrackState): Current state of the track.\n history (OrderedDict): Ordered history of the track's states.\n features (list): List of features extracted from the object for tracking.\n curr_feature (Any): The current feature of the object being tracked.\n score 
(float): The confidence score of the tracking.\n start_frame (int): The frame number where tracking started.\n frame_id (int): The most recent frame ID processed by the track.\n time_since_update (int): Frames passed since the last update.\n location (tuple): The location of the object in the context of multi-camera tracking.\n\nMethods:\n end_frame: Returns the ID of the last frame where the object was tracked.\n next_id: Increments and returns the next global track ID.\n activate: Abstract method to activate the track.\n predict: Abstract method to predict the next state of the track.\n update: Abstract method to update the track with new data.\n mark_lost: Marks the track as lost.\n mark_removed: Marks the track as removed.\n reset_id: Resets the global track ID counter.\n\nExamples:\n Initialize a new track and mark it as lost:\n >>> track = BaseTrack()\n >>> track.mark_lost()\n >>> print(track.state) # Output: 2 (TrackState.Lost)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.OrderedDict", "typing.Any", "numpy" ], "chunk_id": "class_BaseTrack_9a34aaf0" }, { "content": "from collections import deque", "chunk_type": "import", "name": "deque", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deque_9cdabf00" }, { "content": "from typing import Any, List, Optional", "chunk_type": "import", "name": "Any, List, Optional", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Optional_da76344b" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_b77088f8" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_375aabac" }, { "content": "from ultralytics.utils.ops import xywh2xyxy", "chunk_type": "import", "name": "xywh2xyxy", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_xywh2xyxy_e1e81eaf" }, { "content": "from ultralytics.utils.plotting import save_one_box", "chunk_type": "import", "name": "save_one_box", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": 
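The BaseTrack and TrackState chunks above define the track lifecycle (New, Tracked, Lost, Removed) and a shared ID counter. A minimal sketch of driving that lifecycle from a concrete subclass, assuming the `ultralytics` package is installed; `MinimalTrack` is a hypothetical class used only for illustration and is not part of the library:

```python
# Minimal lifecycle sketch; MinimalTrack is hypothetical, only BaseTrack/TrackState mirror the API above.
from ultralytics.trackers.basetrack import BaseTrack, TrackState


class MinimalTrack(BaseTrack):
    """Toy track that implements the abstract hooks as no-ops."""

    def activate(self, frame_id: int) -> None:
        # Draw a unique ID from the class-level counter and mark the track as active.
        self.track_id = self.next_id()
        self.state = TrackState.Tracked
        self.is_activated = True
        self.start_frame = self.frame_id = frame_id

    def predict(self) -> None:  # no motion model in this sketch
        pass

    def update(self, frame_id: int) -> None:
        self.frame_id = frame_id
        self.state = TrackState.Tracked


track = MinimalTrack()
track.activate(frame_id=1)
track.update(frame_id=2)
track.mark_lost()                      # state -> TrackState.Lost (2)
print(track.track_id, track.state, track.end_frame)
BaseTrack.reset_id()                   # reset the global ID counter between sequences
```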
null, "dependencies": null, "chunk_id": "import_save_one_box_fa1c4742" }, { "content": "from .basetrack import TrackState", "chunk_type": "import", "name": "TrackState", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TrackState_fb6e3ae9" }, { "content": "from .byte_tracker import BYTETracker, STrack", "chunk_type": "import", "name": "BYTETracker, STrack", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BYTETracker, STrack_99380cc4" }, { "content": "from .utils import matching", "chunk_type": "import", "name": "matching", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_matching_ece3f4dd" }, { "content": "from .utils.gmc import GMC", "chunk_type": "import", "name": "GMC", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_GMC_35d930cc" }, { "content": "from .utils.kalman_filter import KalmanFilterXYWH", "chunk_type": "import", "name": "KalmanFilterXYWH", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_KalmanFilterXYWH_23048d5d" }, { "content": "class BOTrack(STrack):\n \"\"\"\n An extended version of the STrack class for YOLO, adding object tracking features.\n\n This class extends the STrack class to include additional functionalities for object tracking, such as feature\n smoothing, Kalman filter prediction, and reactivation of tracks.\n\n Attributes:\n shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.\n smooth_feat (np.ndarray): Smoothed feature vector.\n curr_feat (np.ndarray): Current feature vector.\n features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.\n alpha (float): Smoothing factor for the exponential moving average of features.\n mean (np.ndarray): The mean state of the Kalman filter.\n covariance (np.ndarray): The covariance matrix of the Kalman filter.\n\n Methods:\n update_features: Update features vector and smooth it using exponential moving average.\n predict: Predict the mean and covariance using Kalman filter.\n re_activate: Reactivate a track with updated features and optionally new ID.\n update: Update the track with new detection and frame ID.\n tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.\n multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter.\n 
convert_coords: Convert tlwh bounding box coordinates to xywh format.\n tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`.\n\n Examples:\n Create a BOTrack instance and update its features\n >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))\n >>> bo_track.predict()\n >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))\n >>> bo_track.update(new_track, frame_id=2)\n \"\"\"\n\n shared_kalman = KalmanFilterXYWH()\n\n def __init__(\n self, tlwh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50\n ):\n \"\"\"\n Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.\n\n Args:\n tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).\n score (float): Confidence score of the detection.\n cls (int): Class ID of the detected object.\n feat (np.ndarray, optional): Feature vector associated with the detection.\n feat_history (int): Maximum length of the feature history deque.\n\n Examples:\n Initialize a BOTrack object with bounding box, score, class ID, and feature vector\n >>> tlwh = np.array([100, 50, 80, 120])\n >>> score = 0.9\n >>> cls = 1\n >>> feat = np.random.rand(128)\n >>> bo_track = BOTrack(tlwh, score, cls, feat)\n \"\"\"\n super().__init__(tlwh, score, cls)\n\n self.smooth_feat = None\n self.curr_feat = None\n if feat is not None:\n self.update_features(feat)\n self.features = deque([], maxlen=feat_history)\n self.alpha = 0.9\n\n def update_features(self, feat: np.ndarray) -> None:\n \"\"\"Update the feature vector and apply exponential moving average smoothing.\"\"\"\n feat /= np.linalg.norm(feat)\n self.curr_feat = feat\n if self.smooth_feat is None:\n self.smooth_feat = feat\n else:\n self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat\n self.features.append(feat)\n self.smooth_feat /= np.linalg.norm(self.smooth_feat)\n\n def predict(self) -> None:\n \"\"\"Predict the object's future state using the Kalman filter to update its mean and covariance.\"\"\"\n mean_state = self.mean.copy()\n if self.state != TrackState.Tracked:\n mean_state[6] = 0\n mean_state[7] = 0\n\n self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)\n\n def re_activate(self, new_track: \"BOTrack\", frame_id: int, new_id: bool = False) -> None:\n \"\"\"Reactivate a track with updated features and optionally assign a new ID.\"\"\"\n if new_track.curr_feat is not None:\n self.update_features(new_track.curr_feat)\n super().re_activate(new_track, frame_id, new_id)\n\n def update(self, new_track: \"BOTrack\", frame_id: int) -> None:\n \"\"\"Update the track with new detection information and the current frame ID.\"\"\"\n if new_track.curr_feat is not None:\n self.update_features(new_track.curr_feat)\n super().update(new_track, frame_id)\n\n @property\n def tlwh(self) -> np.ndarray:\n \"\"\"Return the current bounding box position in `(top left x, top left y, width, height)` format.\"\"\"\n if self.mean is None:\n return self._tlwh.copy()\n ret = self.mean[:4].copy()\n ret[:2] -= ret[2:] / 2\n return ret\n\n @staticmethod\n def multi_predict(stracks: List[\"BOTrack\"]) -> None:\n \"\"\"Predict the mean and covariance for multiple object tracks using a shared Kalman filter.\"\"\"\n if len(stracks) <= 0:\n return\n multi_mean = np.asarray([st.mean.copy() for st in stracks])\n multi_covariance = 
np.asarray([st.covariance for st in stracks])\n for i, st in enumerate(stracks):\n if st.state != TrackState.Tracked:\n multi_mean[i][6] = 0\n multi_mean[i][7] = 0\n multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)\n for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n stracks[i].mean = mean\n stracks[i].covariance = cov\n\n def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert tlwh bounding box coordinates to xywh format.\"\"\"\n return self.tlwh_to_xywh(tlwh)\n\n @staticmethod\n def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.\"\"\"\n ret = np.asarray(tlwh).copy()\n ret[:2] += ret[2:] / 2\n return ret", "chunk_type": "class", "name": "BOTrack", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 19, "end_line": 151, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "An extended version of the STrack class for YOLO, adding object tracking features.\n\nThis class extends the STrack class to include additional functionalities for object tracking, such as feature\nsmoothing, Kalman filter prediction, and reactivation of tracks.\n\nAttributes:\n shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.\n smooth_feat (np.ndarray): Smoothed feature vector.\n curr_feat (np.ndarray): Current feature vector.\n features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.\n alpha (float): Smoothing factor for the exponential moving average of features.\n mean (np.ndarray): The mean state of the Kalman filter.\n covariance (np.ndarray): The covariance matrix of the Kalman filter.\n\nMethods:\n update_features: Update features vector and smooth it using exponential moving average.\n predict: Predict the mean and covariance using Kalman filter.\n re_activate: Reactivate a track with updated features and optionally new ID.\n update: Update the track with new detection and frame ID.\n tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.\n multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter.\n convert_coords: Convert tlwh bounding box coordinates to xywh format.\n tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`.\n\nExamples:\n Create a BOTrack instance and update its features\n >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))\n >>> bo_track.predict()\n >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))\n >>> bo_track.update(new_track, frame_id=2)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.deque", "typing.Any", "typing.List", "typing.Optional", "numpy", "torch", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.plotting.save_one_box", "basetrack.TrackState", "byte_tracker.BYTETracker", "byte_tracker.STrack", "utils.matching", "utils.gmc.GMC", "utils.kalman_filter.KalmanFilterXYWH", "ultralytics.YOLO", "STrack" ], "chunk_id": "class_BOTrack_878d3d31" }, { "content": "class BOTSORT(BYTETracker):\n \"\"\"\n An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.\n\n Attributes:\n proximity_thresh (float): Threshold for spatial proximity (IoU) between 
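`BOTrack.update_features` above keeps an L2-normalized exponential moving average of ReID embeddings. A standalone numpy sketch of that smoothing rule; the helper name `ema_smooth`, the 128-D random inputs, and the default `alpha=0.9` are illustrative, only the formula mirrors the code above:

```python
import numpy as np


def ema_smooth(smooth_feat, feat, alpha=0.9):
    """Return (smoothed, current) feature vectors, mirroring BOTrack.update_features."""
    feat = feat / np.linalg.norm(feat)             # L2-normalize the incoming embedding
    if smooth_feat is None:                        # first observation: no history yet
        smooth_feat = feat
    else:                                          # exponential moving average of embeddings
        smooth_feat = alpha * smooth_feat + (1 - alpha) * feat
    return smooth_feat / np.linalg.norm(smooth_feat), feat


smooth = None
for _ in range(5):                                 # feed a few noisy 128-D embeddings
    smooth, cur = ema_smooth(smooth, np.random.rand(128))
print(smooth.shape, np.linalg.norm(smooth))        # (128,) and ~1.0
```

The renormalization after each update keeps the smoothed embedding on the unit sphere, so cosine-style appearance distances stay comparable across frames.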
tracks and detections.\n appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.\n encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.\n gmc (GMC): An instance of the GMC algorithm for data association.\n args (Any): Parsed command-line arguments containing tracking parameters.\n\n Methods:\n get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking.\n init_track: Initialize track with detections, scores, and classes.\n get_dists: Get distances between tracks and detections using IoU and (optionally) ReID.\n multi_predict: Predict and track multiple objects with a YOLO model.\n reset: Reset the BOTSORT tracker to its initial state.\n\n Examples:\n Initialize BOTSORT and process detections\n >>> bot_sort = BOTSORT(args, frame_rate=30)\n >>> bot_sort.init_track(dets, scores, cls, img)\n >>> bot_sort.multi_predict(tracks)\n\n Note:\n The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.\n \"\"\"\n\n def __init__(self, args: Any, frame_rate: int = 30):\n \"\"\"\n Initialize BOTSORT object with ReID module and GMC algorithm.\n\n Args:\n args (Any): Parsed command-line arguments containing tracking parameters.\n frame_rate (int): Frame rate of the video being processed.\n\n Examples:\n Initialize BOTSORT with command-line arguments and a specified frame rate:\n >>> args = parse_args()\n >>> bot_sort = BOTSORT(args, frame_rate=30)\n \"\"\"\n super().__init__(args, frame_rate)\n self.gmc = GMC(method=args.gmc_method)\n\n # ReID module\n self.proximity_thresh = args.proximity_thresh\n self.appearance_thresh = args.appearance_thresh\n self.encoder = (\n (lambda feats, s: [f.cpu().numpy() for f in feats]) # native features do not require any model\n if args.with_reid and self.args.model == \"auto\"\n else ReID(args.model)\n if args.with_reid\n else None\n )\n\n def get_kalmanfilter(self) -> KalmanFilterXYWH:\n \"\"\"Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process.\"\"\"\n return KalmanFilterXYWH()\n\n def init_track(\n self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None\n ) -> List[BOTrack]:\n \"\"\"Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features.\"\"\"\n if len(dets) == 0:\n return []\n if self.args.with_reid and self.encoder is not None:\n features_keep = self.encoder(img, dets)\n return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections\n else:\n return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections\n\n def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray:\n \"\"\"Calculate distances between tracks and detections using IoU and optionally ReID embeddings.\"\"\"\n dists = matching.iou_distance(tracks, detections)\n dists_mask = dists > (1 - self.proximity_thresh)\n\n if self.args.fuse_score:\n dists = matching.fuse_score(dists, detections)\n\n if self.args.with_reid and self.encoder is not None:\n emb_dists = matching.embedding_distance(tracks, detections) / 2.0\n emb_dists[emb_dists > (1 - self.appearance_thresh)] = 1.0\n emb_dists[dists_mask] = 1.0\n dists = np.minimum(dists, emb_dists)\n return dists\n\n def multi_predict(self, tracks: List[BOTrack]) -> None:\n \"\"\"Predict the mean and covariance of multiple object tracks using a shared Kalman filter.\"\"\"\n 
BOTrack.multi_predict(tracks)\n\n def reset(self) -> None:\n \"\"\"Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states.\"\"\"\n super().reset()\n self.gmc.reset_params()", "chunk_type": "class", "name": "BOTSORT", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 154, "end_line": 247, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.\n\nAttributes:\n proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.\n appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.\n encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.\n gmc (GMC): An instance of the GMC algorithm for data association.\n args (Any): Parsed command-line arguments containing tracking parameters.\n\nMethods:\n get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking.\n init_track: Initialize track with detections, scores, and classes.\n get_dists: Get distances between tracks and detections using IoU and (optionally) ReID.\n multi_predict: Predict and track multiple objects with a YOLO model.\n reset: Reset the BOTSORT tracker to its initial state.\n\nExamples:\n Initialize BOTSORT and process detections\n >>> bot_sort = BOTSORT(args, frame_rate=30)\n >>> bot_sort.init_track(dets, scores, cls, img)\n >>> bot_sort.multi_predict(tracks)\n\nNote:\n The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.deque", "typing.Any", "typing.List", "typing.Optional", "numpy", "torch", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.plotting.save_one_box", "basetrack.TrackState", "byte_tracker.BYTETracker", "byte_tracker.STrack", "utils.matching", "utils.gmc.GMC", "utils.kalman_filter.KalmanFilterXYWH", "ultralytics.YOLO", "BYTETracker" ], "chunk_id": "class_BOTSORT_02df815b" }, { "content": "class ReID:\n \"\"\"YOLO model as encoder for re-identification.\"\"\"\n\n def __init__(self, model: str):\n \"\"\"\n Initialize encoder for re-identification.\n\n Args:\n model (str): Path to the YOLO model for re-identification.\n \"\"\"\n from ultralytics import YOLO\n\n self.model = YOLO(model)\n self.model(embed=[len(self.model.model.model) - 2 if \".pt\" in model else -1], verbose=False, save=False) # init\n\n def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:\n \"\"\"Extract embeddings for detected objects.\"\"\"\n feats = self.model.predictor(\n [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]\n )\n if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:\n feats = feats[0] # batched prediction with non-PyTorch backend\n return [f.cpu().numpy() for f in feats]", "chunk_type": "class", "name": "ReID", "file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py", "start_line": 250, "end_line": 272, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "YOLO model as encoder for re-identification.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.deque", "typing.Any", "typing.List", "typing.Optional", "numpy", "torch", "ultralytics.utils.ops.xywh2xyxy", 
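`BOTSORT.get_dists` above gates appearance (ReID) distances with an IoU proximity mask and then takes an element-wise minimum of the two cost matrices. A rough numpy sketch of that fusion step over hypothetical precomputed costs; the array values and threshold values are illustrative, and the optional `fuse_score` step is omitted:

```python
import numpy as np

# Hypothetical costs for 2 tracks x 3 detections: (1 - IoU) and halved embedding distance.
iou_dists = np.array([[0.2, 0.9, 0.6], [0.8, 0.1, 0.7]])
emb_dists = np.array([[0.05, 0.4, 0.3], [0.5, 0.02, 0.8]])

proximity_thresh, appearance_thresh = 0.5, 0.25          # illustrative thresholds

dists_mask = iou_dists > (1 - proximity_thresh)          # pairs that are spatially too far apart
emb_dists = emb_dists.copy()
emb_dists[emb_dists > (1 - appearance_thresh)] = 1.0     # reject weak appearance matches
emb_dists[dists_mask] = 1.0                              # reject pairs outside the IoU gate
fused = np.minimum(iou_dists, emb_dists)                 # cost matrix passed to linear assignment
print(fused)
```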
"ultralytics.utils.plotting.save_one_box", "basetrack.TrackState", "byte_tracker.BYTETracker", "byte_tracker.STrack", "utils.matching", "utils.gmc.GMC", "utils.kalman_filter.KalmanFilterXYWH", "ultralytics.YOLO" ], "chunk_id": "class_ReID_2fea07e5" }, { "content": "from typing import Any, List, Optional, Tuple", "chunk_type": "import", "name": "Any, List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Optional, Tuple_eafa25fe" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_97b17d82" }, { "content": "from ..utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_10c4c7b7" }, { "content": "from ..utils.ops import xywh2ltwh", "chunk_type": "import", "name": "xywh2ltwh", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_xywh2ltwh_6391db06" }, { "content": "from .basetrack import BaseTrack, TrackState", "chunk_type": "import", "name": "BaseTrack, TrackState", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseTrack, TrackState_d0ca7671" }, { "content": "from .utils import matching", "chunk_type": "import", "name": "matching", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_matching_89619227" }, { "content": "from .utils.kalman_filter import KalmanFilterXYAH", "chunk_type": "import", "name": "KalmanFilterXYAH", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_KalmanFilterXYAH_887c85f7" }, { "content": "class STrack(BaseTrack):\n \"\"\"\n Single object tracking representation that uses Kalman filtering for state estimation.\n\n This class is responsible for storing all the information regarding individual tracklets and performs state updates\n and predictions based on Kalman filter.\n\n 
Attributes:\n shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.\n _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.\n kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.\n mean (np.ndarray): Mean state estimate vector.\n covariance (np.ndarray): Covariance of state estimate.\n is_activated (bool): Boolean flag indicating if the track has been activated.\n score (float): Confidence score of the track.\n tracklet_len (int): Length of the tracklet.\n cls (Any): Class label for the object.\n idx (int): Index or identifier for the object.\n frame_id (int): Current frame ID.\n start_frame (int): Frame where the object was first detected.\n angle (float | None): Optional angle information for oriented bounding boxes.\n\n Methods:\n predict: Predict the next state of the object using Kalman filter.\n multi_predict: Predict the next states for multiple tracks.\n multi_gmc: Update multiple track states using a homography matrix.\n activate: Activate a new tracklet.\n re_activate: Reactivate a previously lost tracklet.\n update: Update the state of a matched track.\n convert_coords: Convert bounding box to x-y-aspect-height format.\n tlwh_to_xyah: Convert tlwh bounding box to xyah format.\n\n Examples:\n Initialize and activate a new track\n >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls=\"person\")\n >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)\n \"\"\"\n\n shared_kalman = KalmanFilterXYAH()\n\n def __init__(self, xywh: List[float], score: float, cls: Any):\n \"\"\"\n Initialize a new STrack instance.\n\n Args:\n xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where\n (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.\n score (float): Confidence score of the detection.\n cls (Any): Class label for the detected object.\n\n Examples:\n >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]\n >>> score = 0.9\n >>> cls = \"person\"\n >>> track = STrack(xywh, score, cls)\n \"\"\"\n super().__init__()\n # xywh+idx or xywha+idx\n assert len(xywh) in {5, 6}, f\"expected 5 or 6 values but got {len(xywh)}\"\n self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)\n self.kalman_filter = None\n self.mean, self.covariance = None, None\n self.is_activated = False\n\n self.score = score\n self.tracklet_len = 0\n self.cls = cls\n self.idx = xywh[-1]\n self.angle = xywh[4] if len(xywh) == 6 else None\n\n def predict(self):\n \"\"\"Predict the next state (mean and covariance) of the object using the Kalman filter.\"\"\"\n mean_state = self.mean.copy()\n if self.state != TrackState.Tracked:\n mean_state[7] = 0\n self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)\n\n @staticmethod\n def multi_predict(stracks: List[\"STrack\"]):\n \"\"\"Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances.\"\"\"\n if len(stracks) <= 0:\n return\n multi_mean = np.asarray([st.mean.copy() for st in stracks])\n multi_covariance = np.asarray([st.covariance for st in stracks])\n for i, st in enumerate(stracks):\n if st.state != TrackState.Tracked:\n multi_mean[i][7] = 0\n multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)\n for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n stracks[i].mean = mean\n 
stracks[i].covariance = cov\n\n @staticmethod\n def multi_gmc(stracks: List[\"STrack\"], H: np.ndarray = np.eye(2, 3)):\n \"\"\"Update state tracks positions and covariances using a homography matrix for multiple tracks.\"\"\"\n if len(stracks) > 0:\n multi_mean = np.asarray([st.mean.copy() for st in stracks])\n multi_covariance = np.asarray([st.covariance for st in stracks])\n\n R = H[:2, :2]\n R8x8 = np.kron(np.eye(4, dtype=float), R)\n t = H[:2, 2]\n\n for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n mean = R8x8.dot(mean)\n mean[:2] += t\n cov = R8x8.dot(cov).dot(R8x8.transpose())\n\n stracks[i].mean = mean\n stracks[i].covariance = cov\n\n def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):\n \"\"\"Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.\"\"\"\n self.kalman_filter = kalman_filter\n self.track_id = self.next_id()\n self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))\n\n self.tracklet_len = 0\n self.state = TrackState.Tracked\n if frame_id == 1:\n self.is_activated = True\n self.frame_id = frame_id\n self.start_frame = frame_id\n\n def re_activate(self, new_track: \"STrack\", frame_id: int, new_id: bool = False):\n \"\"\"Reactivate a previously lost track using new detection data and update its state and attributes.\"\"\"\n self.mean, self.covariance = self.kalman_filter.update(\n self.mean, self.covariance, self.convert_coords(new_track.tlwh)\n )\n self.tracklet_len = 0\n self.state = TrackState.Tracked\n self.is_activated = True\n self.frame_id = frame_id\n if new_id:\n self.track_id = self.next_id()\n self.score = new_track.score\n self.cls = new_track.cls\n self.angle = new_track.angle\n self.idx = new_track.idx\n\n def update(self, new_track: \"STrack\", frame_id: int):\n \"\"\"\n Update the state of a matched track.\n\n Args:\n new_track (STrack): The new track containing updated information.\n frame_id (int): The ID of the current frame.\n\n Examples:\n Update the state of a track with new detection information\n >>> track = STrack([100, 200, 50, 80, 0.9, 1])\n >>> new_track = STrack([105, 205, 55, 85, 0.95, 1])\n >>> track.update(new_track, 2)\n \"\"\"\n self.frame_id = frame_id\n self.tracklet_len += 1\n\n new_tlwh = new_track.tlwh\n self.mean, self.covariance = self.kalman_filter.update(\n self.mean, self.covariance, self.convert_coords(new_tlwh)\n )\n self.state = TrackState.Tracked\n self.is_activated = True\n\n self.score = new_track.score\n self.cls = new_track.cls\n self.angle = new_track.angle\n self.idx = new_track.idx\n\n def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.\"\"\"\n return self.tlwh_to_xyah(tlwh)\n\n @property\n def tlwh(self) -> np.ndarray:\n \"\"\"Get the bounding box in top-left-width-height format from the current state estimate.\"\"\"\n if self.mean is None:\n return self._tlwh.copy()\n ret = self.mean[:4].copy()\n ret[2] *= ret[3]\n ret[:2] -= ret[2:] / 2\n return ret\n\n @property\n def xyxy(self) -> np.ndarray:\n \"\"\"Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.\"\"\"\n ret = self.tlwh.copy()\n ret[2:] += ret[:2]\n return ret\n\n @staticmethod\n def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.\"\"\"\n ret = np.asarray(tlwh).copy()\n ret[:2] += ret[2:] / 2\n 
ret[2] /= ret[3]\n return ret\n\n @property\n def xywh(self) -> np.ndarray:\n \"\"\"Get the current position of the bounding box in (center x, center y, width, height) format.\"\"\"\n ret = np.asarray(self.tlwh).copy()\n ret[:2] += ret[2:] / 2\n return ret\n\n @property\n def xywha(self) -> np.ndarray:\n \"\"\"Get position in (center x, center y, width, height, angle) format, warning if angle is missing.\"\"\"\n if self.angle is None:\n LOGGER.warning(\"`angle` attr not found, returning `xywh` instead.\")\n return self.xywh\n return np.concatenate([self.xywh, self.angle[None]])\n\n @property\n def result(self) -> List[float]:\n \"\"\"Get the current tracking results in the appropriate bounding box format.\"\"\"\n coords = self.xyxy if self.angle is None else self.xywha\n return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]\n\n def __repr__(self) -> str:\n \"\"\"Return a string representation of the STrack object including start frame, end frame, and track ID.\"\"\"\n return f\"OT_{self.track_id}_({self.start_frame}-{self.end_frame})\"", "chunk_type": "class", "name": "STrack", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 14, "end_line": 235, "start_col": 0, "end_col": 74, "parent_name": null, "docstring": "Single object tracking representation that uses Kalman filtering for state estimation.\n\nThis class is responsible for storing all the information regarding individual tracklets and performs state updates\nand predictions based on Kalman filter.\n\nAttributes:\n shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.\n _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.\n kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.\n mean (np.ndarray): Mean state estimate vector.\n covariance (np.ndarray): Covariance of state estimate.\n is_activated (bool): Boolean flag indicating if the track has been activated.\n score (float): Confidence score of the track.\n tracklet_len (int): Length of the tracklet.\n cls (Any): Class label for the object.\n idx (int): Index or identifier for the object.\n frame_id (int): Current frame ID.\n start_frame (int): Frame where the object was first detected.\n angle (float | None): Optional angle information for oriented bounding boxes.\n\nMethods:\n predict: Predict the next state of the object using Kalman filter.\n multi_predict: Predict the next states for multiple tracks.\n multi_gmc: Update multiple track states using a homography matrix.\n activate: Activate a new tracklet.\n re_activate: Reactivate a previously lost tracklet.\n update: Update the state of a matched track.\n convert_coords: Convert bounding box to x-y-aspect-height format.\n tlwh_to_xyah: Convert tlwh bounding box to xyah format.\n\nExamples:\n Initialize and activate a new track\n >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls=\"person\")\n >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "numpy", "utils.LOGGER", "utils.ops.xywh2ltwh", "basetrack.BaseTrack", "basetrack.TrackState", "utils.matching", "utils.kalman_filter.KalmanFilterXYAH", "BaseTrack" ], "chunk_id": "class_STrack_55d317ad" }, { "content": "class BYTETracker:\n \"\"\"\n BYTETracker: A tracking algorithm built on top 
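STrack exposes several bounding-box views (tlwh, xyah, xyxy, xywh) built from the same small conversions. A self-contained numpy sketch of the pure coordinate transforms used above, independent of the Kalman state; the function names echo the static helpers and the sample box is arbitrary:

```python
import numpy as np


def tlwh_to_xyah(tlwh):
    """(top-left x, top-left y, w, h) -> (center x, center y, aspect ratio w/h, h)."""
    ret = np.asarray(tlwh, dtype=np.float32).copy()
    ret[:2] += ret[2:] / 2      # shift corner to center
    ret[2] /= ret[3]            # width -> aspect ratio
    return ret


def tlwh_to_xyxy(tlwh):
    """(top-left x, top-left y, w, h) -> (min x, min y, max x, max y)."""
    ret = np.asarray(tlwh, dtype=np.float32).copy()
    ret[2:] += ret[:2]
    return ret


box = [100.0, 50.0, 80.0, 40.0]     # arbitrary sample box
print(tlwh_to_xyah(box))            # [140.  70.   2.  40.]
print(tlwh_to_xyxy(box))            # [100.  50. 180.  90.]
```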
of YOLOv8 for object detection and tracking.\n\n This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a\n video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for\n predicting the new object locations, and performs data association.\n\n Attributes:\n tracked_stracks (List[STrack]): List of successfully activated tracks.\n lost_stracks (List[STrack]): List of lost tracks.\n removed_stracks (List[STrack]): List of removed tracks.\n frame_id (int): The current frame ID.\n args (Namespace): Command-line arguments.\n max_time_lost (int): The maximum frames for a track to be considered as 'lost'.\n kalman_filter (KalmanFilterXYAH): Kalman Filter object.\n\n Methods:\n update: Update object tracker with new detections.\n get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.\n init_track: Initialize object tracking with detections.\n get_dists: Calculate the distance between tracks and detections.\n multi_predict: Predict the location of tracks.\n reset_id: Reset the ID counter of STrack.\n reset: Reset the tracker by clearing all tracks.\n joint_stracks: Combine two lists of stracks.\n sub_stracks: Filter out the stracks present in the second list from the first list.\n remove_duplicate_stracks: Remove duplicate stracks based on IoU.\n\n Examples:\n Initialize BYTETracker and update with detection results\n >>> tracker = BYTETracker(args, frame_rate=30)\n >>> results = yolo_model.detect(image)\n >>> tracked_objects = tracker.update(results)\n \"\"\"\n\n def __init__(self, args, frame_rate: int = 30):\n \"\"\"\n Initialize a BYTETracker instance for object tracking.\n\n Args:\n args (Namespace): Command-line arguments containing tracking parameters.\n frame_rate (int): Frame rate of the video sequence.\n\n Examples:\n Initialize BYTETracker with command-line arguments and a frame rate of 30\n >>> args = Namespace(track_buffer=30)\n >>> tracker = BYTETracker(args, frame_rate=30)\n \"\"\"\n self.tracked_stracks = [] # type: List[STrack]\n self.lost_stracks = [] # type: List[STrack]\n self.removed_stracks = [] # type: List[STrack]\n\n self.frame_id = 0\n self.args = args\n self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)\n self.kalman_filter = self.get_kalmanfilter()\n self.reset_id()\n\n def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray:\n \"\"\"Update the tracker with new detections and return the current list of tracked objects.\"\"\"\n self.frame_id += 1\n activated_stracks = []\n refind_stracks = []\n lost_stracks = []\n removed_stracks = []\n\n scores = results.conf\n bboxes = results.xywhr if hasattr(results, \"xywhr\") else results.xywh\n # Add index\n bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)\n cls = results.cls\n\n remain_inds = scores >= self.args.track_high_thresh\n inds_low = scores > self.args.track_low_thresh\n inds_high = scores < self.args.track_high_thresh\n\n inds_second = inds_low & inds_high\n dets_second = bboxes[inds_second]\n dets = bboxes[remain_inds]\n scores_keep = scores[remain_inds]\n scores_second = scores[inds_second]\n cls_keep = cls[remain_inds]\n cls_second = cls[inds_second]\n\n detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)\n # Add newly detected tracklets to tracked_stracks\n unconfirmed = []\n tracked_stracks = [] # type: List[STrack]\n for track in 
self.tracked_stracks:\n if not track.is_activated:\n unconfirmed.append(track)\n else:\n tracked_stracks.append(track)\n # Step 2: First association, with high score detection boxes\n strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)\n # Predict the current location with KF\n self.multi_predict(strack_pool)\n if hasattr(self, \"gmc\") and img is not None:\n # use try-except here to bypass errors from gmc module\n try:\n warp = self.gmc.apply(img, dets)\n except Exception:\n warp = np.eye(2, 3)\n STrack.multi_gmc(strack_pool, warp)\n STrack.multi_gmc(unconfirmed, warp)\n\n dists = self.get_dists(strack_pool, detections)\n matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)\n\n for itracked, idet in matches:\n track = strack_pool[itracked]\n det = detections[idet]\n if track.state == TrackState.Tracked:\n track.update(det, self.frame_id)\n activated_stracks.append(track)\n else:\n track.re_activate(det, self.frame_id, new_id=False)\n refind_stracks.append(track)\n # Step 3: Second association, with low score detection boxes association the untrack to the low score detections\n detections_second = self.init_track(dets_second, scores_second, cls_second, img if feats is None else feats)\n r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]\n # TODO\n dists = matching.iou_distance(r_tracked_stracks, detections_second)\n matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)\n for itracked, idet in matches:\n track = r_tracked_stracks[itracked]\n det = detections_second[idet]\n if track.state == TrackState.Tracked:\n track.update(det, self.frame_id)\n activated_stracks.append(track)\n else:\n track.re_activate(det, self.frame_id, new_id=False)\n refind_stracks.append(track)\n\n for it in u_track:\n track = r_tracked_stracks[it]\n if track.state != TrackState.Lost:\n track.mark_lost()\n lost_stracks.append(track)\n # Deal with unconfirmed tracks, usually tracks with only one beginning frame\n detections = [detections[i] for i in u_detection]\n dists = self.get_dists(unconfirmed, detections)\n matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)\n for itracked, idet in matches:\n unconfirmed[itracked].update(detections[idet], self.frame_id)\n activated_stracks.append(unconfirmed[itracked])\n for it in u_unconfirmed:\n track = unconfirmed[it]\n track.mark_removed()\n removed_stracks.append(track)\n # Step 4: Init new stracks\n for inew in u_detection:\n track = detections[inew]\n if track.score < self.args.new_track_thresh:\n continue\n track.activate(self.kalman_filter, self.frame_id)\n activated_stracks.append(track)\n # Step 5: Update state\n for track in self.lost_stracks:\n if self.frame_id - track.end_frame > self.max_time_lost:\n track.mark_removed()\n removed_stracks.append(track)\n\n self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]\n self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)\n self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)\n self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)\n self.lost_stracks.extend(lost_stracks)\n self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)\n self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)\n self.removed_stracks.extend(removed_stracks)\n if 
len(self.removed_stracks) > 1000:\n self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum\n\n return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)\n\n def get_kalmanfilter(self) -> KalmanFilterXYAH:\n \"\"\"Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.\"\"\"\n return KalmanFilterXYAH()\n\n def init_track(\n self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None\n ) -> List[STrack]:\n \"\"\"Initialize object tracking with given detections, scores, and class labels using the STrack algorithm.\"\"\"\n return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections\n\n def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray:\n \"\"\"Calculate the distance between tracks and detections using IoU and optionally fuse scores.\"\"\"\n dists = matching.iou_distance(tracks, detections)\n if self.args.fuse_score:\n dists = matching.fuse_score(dists, detections)\n return dists\n\n def multi_predict(self, tracks: List[STrack]):\n \"\"\"Predict the next states for multiple tracks using Kalman filter.\"\"\"\n STrack.multi_predict(tracks)\n\n @staticmethod\n def reset_id():\n \"\"\"Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions.\"\"\"\n STrack.reset_id()\n\n def reset(self):\n \"\"\"Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.\"\"\"\n self.tracked_stracks = [] # type: List[STrack]\n self.lost_stracks = [] # type: List[STrack]\n self.removed_stracks = [] # type: List[STrack]\n self.frame_id = 0\n self.kalman_filter = self.get_kalmanfilter()\n self.reset_id()\n\n @staticmethod\n def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:\n \"\"\"Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.\"\"\"\n exists = {}\n res = []\n for t in tlista:\n exists[t.track_id] = 1\n res.append(t)\n for t in tlistb:\n tid = t.track_id\n if not exists.get(tid, 0):\n exists[tid] = 1\n res.append(t)\n return res\n\n @staticmethod\n def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:\n \"\"\"Filter out the stracks present in the second list from the first list.\"\"\"\n track_ids_b = {t.track_id for t in tlistb}\n return [t for t in tlista if t.track_id not in track_ids_b]\n\n @staticmethod\n def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]:\n \"\"\"Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance.\"\"\"\n pdist = matching.iou_distance(stracksa, stracksb)\n pairs = np.where(pdist < 0.15)\n dupa, dupb = [], []\n for p, q in zip(*pairs):\n timep = stracksa[p].frame_id - stracksa[p].start_frame\n timeq = stracksb[q].frame_id - stracksb[q].start_frame\n if timep > timeq:\n dupb.append(q)\n else:\n dupa.append(p)\n resa = [t for i, t in enumerate(stracksa) if i not in dupa]\n resb = [t for i, t in enumerate(stracksb) if i not in dupb]\n return resa, resb", "chunk_type": "class", "name": "BYTETracker", "file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py", "start_line": 238, "end_line": 486, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.\n\nThis class encapsulates the functionality for 
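The core of `BYTETracker.update` above is the BYTE idea: split detections into high- and low-confidence sets, associate the confident ones first, then use the low-confidence remainder to rescue unmatched tracks. A compressed numpy sketch of just that splitting step; the score values and thresholds are illustrative, while the mask logic mirrors the code above:

```python
import numpy as np

scores = np.array([0.92, 0.65, 0.30, 0.15, 0.55])
track_high_thresh, track_low_thresh = 0.6, 0.2       # illustrative threshold values

remain_inds = scores >= track_high_thresh            # round 1: confident detections
inds_second = (scores > track_low_thresh) & (scores < track_high_thresh)  # round 2 candidates

print(np.flatnonzero(remain_inds))    # [0 1] -> matched against tracked + lost tracks first
print(np.flatnonzero(inds_second))    # [2 4] -> used to rescue tracks left unmatched by round 1
```

Detections below `track_low_thresh` are discarded entirely, and anything still unmatched after both rounds either starts a new track (if above `new_track_thresh`) or marks the track as lost.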
initializing, updating, and managing the tracks for detected objects in a\nvideo sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for\npredicting the new object locations, and performs data association.\n\nAttributes:\n tracked_stracks (List[STrack]): List of successfully activated tracks.\n lost_stracks (List[STrack]): List of lost tracks.\n removed_stracks (List[STrack]): List of removed tracks.\n frame_id (int): The current frame ID.\n args (Namespace): Command-line arguments.\n max_time_lost (int): The maximum frames for a track to be considered as 'lost'.\n kalman_filter (KalmanFilterXYAH): Kalman Filter object.\n\nMethods:\n update: Update object tracker with new detections.\n get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.\n init_track: Initialize object tracking with detections.\n get_dists: Calculate the distance between tracks and detections.\n multi_predict: Predict the location of tracks.\n reset_id: Reset the ID counter of STrack.\n reset: Reset the tracker by clearing all tracks.\n joint_stracks: Combine two lists of stracks.\n sub_stracks: Filter out the stracks present in the second list from the first list.\n remove_duplicate_stracks: Remove duplicate stracks based on IoU.\n\nExamples:\n Initialize BYTETracker and update with detection results\n >>> tracker = BYTETracker(args, frame_rate=30)\n >>> results = yolo_model.detect(image)\n >>> tracked_objects = tracker.update(results)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.List", "typing.Optional", "typing.Tuple", "numpy", "utils.LOGGER", "utils.ops.xywh2ltwh", "basetrack.BaseTrack", "basetrack.TrackState", "utils.matching", "utils.kalman_filter.KalmanFilterXYAH" ], "chunk_id": "class_BYTETracker_2c47ebb5" }, { "content": "from functools import partial", "chunk_type": "import", "name": "partial", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_partial_e3539b4b" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_1d7eecc6" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_65d492ab" }, { "content": "from ultralytics.utils import YAML, IterableSimpleNamespace", "chunk_type": "import", "name": "YAML, IterableSimpleNamespace", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YAML, IterableSimpleNamespace_cd5d2008" }, { "content": "from 
ultralytics.utils.checks import check_yaml", "chunk_type": "import", "name": "check_yaml", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_yaml_b0cae41a" }, { "content": "from .bot_sort import BOTSORT", "chunk_type": "import", "name": "BOTSORT", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BOTSORT_62cff40a" }, { "content": "from .byte_tracker import BYTETracker", "chunk_type": "import", "name": "BYTETracker", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BYTETracker_11b96f56" }, { "content": "TRACKER_MAP = {\"bytetrack\": BYTETracker, \"botsort\": BOTSORT}", "chunk_type": "variable", "name": "TRACKER_MAP", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TRACKER_MAP_31ce8c23" }, { "content": "def on_predict_start(predictor: object, persist: bool = False) -> None:\n \"\"\"\n Initialize trackers for object tracking during prediction.\n\n Args:\n predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\n Examples:\n Initialize trackers for a predictor object\n >>> predictor = SomePredictorClass()\n >>> on_predict_start(predictor, persist=True)\n \"\"\"\n if predictor.args.task == \"classify\":\n raise ValueError(\"❌ Classification doesn't support 'mode=track'\")\n\n if hasattr(predictor, \"trackers\") and persist:\n return\n\n tracker = check_yaml(predictor.args.tracker)\n cfg = IterableSimpleNamespace(**YAML.load(tracker))\n\n if cfg.tracker_type not in {\"bytetrack\", \"botsort\"}:\n raise AssertionError(f\"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'\")\n\n predictor._feats = None # reset in case used earlier\n if hasattr(predictor, \"_hook\"):\n predictor._hook.remove()\n if cfg.tracker_type == \"botsort\" and cfg.with_reid and cfg.model == \"auto\":\n from ultralytics.nn.modules.head import Detect\n\n if not (\n isinstance(predictor.model.model, torch.nn.Module)\n and isinstance(predictor.model.model.model[-1], Detect)\n and not predictor.model.model.model[-1].end2end\n ):\n cfg.model = \"yolo11n-cls.pt\"\n else:\n # Register hook to extract input of Detect layer\n def pre_hook(module, input):\n predictor._feats = list(input[0]) # unroll to new list to avoid mutation in forward\n\n predictor._hook = predictor.model.model.model[-1].register_forward_pre_hook(pre_hook)\n\n trackers = []\n for _ in range(predictor.dataset.bs):\n tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)\n trackers.append(tracker)\n if 
predictor.dataset.mode != \"stream\": # only need one tracker for other modes\n break\n predictor.trackers = trackers\n predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video", "chunk_type": "function", "name": "on_predict_start", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 18, "end_line": 69, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": "Initialize trackers for object tracking during prediction.\n\nArgs:\n predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\nExamples:\n Initialize trackers for a predictor object\n >>> predictor = SomePredictorClass()\n >>> on_predict_start(predictor, persist=True)", "parameters": [ "predictor: object", "persist: bool" ], "return_type": "None", "decorators": [], "complexity_score": 9, "dependencies": [ "functools.partial", "pathlib.Path", "torch", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks.check_yaml", "bot_sort.BOTSORT", "byte_tracker.BYTETracker", "ultralytics.nn.modules.head.Detect" ], "chunk_id": "function_on_predict_start_b9207017" }, { "content": "def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:\n \"\"\"\n Postprocess detected boxes and update with object tracking.\n\n Args:\n predictor (object): The predictor object containing the predictions.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\n Examples:\n Postprocess predictions and update with tracking\n >>> predictor = YourPredictorClass()\n >>> on_predict_postprocess_end(predictor, persist=True)\n \"\"\"\n is_obb = predictor.args.task == \"obb\"\n is_stream = predictor.dataset.mode == \"stream\"\n for i, result in enumerate(predictor.results):\n tracker = predictor.trackers[i if is_stream else 0]\n vid_path = predictor.save_dir / Path(result.path).name\n if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:\n tracker.reset()\n predictor.vid_path[i if is_stream else 0] = vid_path\n\n det = (result.obb if is_obb else result.boxes).cpu().numpy()\n tracks = tracker.update(det, result.orig_img, getattr(result, \"feats\", None))\n if len(tracks) == 0:\n continue\n idx = tracks[:, -1].astype(int)\n predictor.results[i] = result[idx]\n\n update_args = {\"obb\" if is_obb else \"boxes\": torch.as_tensor(tracks[:, :-1])}\n predictor.results[i].update(**update_args)", "chunk_type": "function", "name": "on_predict_postprocess_end", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 72, "end_line": 102, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": "Postprocess detected boxes and update with object tracking.\n\nArgs:\n predictor (object): The predictor object containing the predictions.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\nExamples:\n Postprocess predictions and update with tracking\n >>> predictor = YourPredictorClass()\n >>> on_predict_postprocess_end(predictor, persist=True)", "parameters": [ "predictor: object", "persist: bool" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "functools.partial", "pathlib.Path", "torch", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks.check_yaml", "bot_sort.BOTSORT", "byte_tracker.BYTETracker", 
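`on_predict_start` above resolves the tracker YAML into a namespace and dispatches on `tracker_type` through `TRACKER_MAP`. A minimal sketch of that config-to-tracker dispatch outside the predictor, assuming a recent `ultralytics` install where the packaged `bytetrack.yaml` resolves via `check_yaml`; the frame rate value is illustrative:

```python
# Sketch of the config loading and dispatch used in on_predict_start.
from ultralytics.trackers import BOTSORT, BYTETracker
from ultralytics.utils import YAML, IterableSimpleNamespace
from ultralytics.utils.checks import check_yaml

TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}

cfg = IterableSimpleNamespace(**YAML.load(check_yaml("bytetrack.yaml")))  # packaged tracker config
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)          # one tracker per stream
print(type(tracker).__name__, cfg.tracker_type)
```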
"ultralytics.nn.modules.head.Detect" ], "chunk_id": "function_on_predict_postprocess_end_6cfbae9e" }, { "content": "def register_tracker(model: object, persist: bool) -> None:\n \"\"\"\n Register tracking callbacks to the model for object tracking during prediction.\n\n Args:\n model (object): The model object to register tracking callbacks for.\n persist (bool): Whether to persist the trackers if they already exist.\n\n Examples:\n Register tracking callbacks to a YOLO model\n >>> model = YOLOModel()\n >>> register_tracker(model, persist=True)\n \"\"\"\n model.add_callback(\"on_predict_start\", partial(on_predict_start, persist=persist))\n model.add_callback(\"on_predict_postprocess_end\", partial(on_predict_postprocess_end, persist=persist))", "chunk_type": "function", "name": "register_tracker", "file_path": "ultralytics\\ultralytics\\trackers\\track.py", "start_line": 105, "end_line": 119, "start_col": 0, "end_col": 106, "parent_name": null, "docstring": "Register tracking callbacks to the model for object tracking during prediction.\n\nArgs:\n model (object): The model object to register tracking callbacks for.\n persist (bool): Whether to persist the trackers if they already exist.\n\nExamples:\n Register tracking callbacks to a YOLO model\n >>> model = YOLOModel()\n >>> register_tracker(model, persist=True)", "parameters": [ "model: object", "persist: bool" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "pathlib.Path", "torch", "ultralytics.utils.YAML", "ultralytics.utils.IterableSimpleNamespace", "ultralytics.utils.checks.check_yaml", "bot_sort.BOTSORT", "byte_tracker.BYTETracker", "ultralytics.nn.modules.head.Detect" ], "chunk_id": "function_register_tracker_a54d453d" }, { "content": "from .bot_sort import BOTSORT", "chunk_type": "import", "name": "BOTSORT", "file_path": "ultralytics\\ultralytics\\trackers\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BOTSORT_cca8eb05" }, { "content": "from .byte_tracker import BYTETracker", "chunk_type": "import", "name": "BYTETracker", "file_path": "ultralytics\\ultralytics\\trackers\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BYTETracker_15e97245" }, { "content": "from .track import register_tracker", "chunk_type": "import", "name": "register_tracker", "file_path": "ultralytics\\ultralytics\\trackers\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_register_tracker_4ebabca9" }, { "content": "__all__ = \"register_tracker\", \"BOTSORT\", \"BYTETracker\" # allow simpler import", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\trackers\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___6e5980f6" }, { "content": "import os", "chunk_type": "import", "name": "os", 
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_79cd52cc" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_a261682b" }, { "content": "from typing import Union", "chunk_type": "import", "name": "Union", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Union_a82fbc39" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_17b39cfb" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_cc39b187" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, colorstr", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, colorstr_765852ba" }, { "content": "from ultralytics.utils.torch_utils import autocast, profile_ops", "chunk_type": "import", "name": "autocast, profile_ops", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_autocast, profile_ops_8fe6c7a2" }, { "content": "def check_train_batch_size(\n model: torch.nn.Module,\n imgsz: int = 640,\n amp: bool = True,\n batch: Union[int, float] = -1,\n max_num_obj: int = 1,\n) -> int:\n \"\"\"\n Compute optimal YOLO training batch size using the autobatch() function.\n\n Args:\n model (torch.nn.Module): YOLO model to check batch size for.\n imgsz (int, optional): Image size used for training.\n amp (bool, optional): Use automatic mixed precision if True.\n batch (int | float, optional): Fraction of GPU memory to use. 
If -1, use default.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\n Returns:\n (int): Optimal batch size computed using the autobatch() function.\n\n Notes:\n If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.\n Otherwise, a default fraction of 0.6 is used.\n \"\"\"\n with autocast(enabled=amp):\n return autobatch(\n deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj\n )", "chunk_type": "function", "name": "check_train_batch_size", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 15, "end_line": 42, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Compute optimal YOLO training batch size using the autobatch() function.\n\nArgs:\n model (torch.nn.Module): YOLO model to check batch size for.\n imgsz (int, optional): Image size used for training.\n amp (bool, optional): Use automatic mixed precision if True.\n batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\nReturns:\n (int): Optimal batch size computed using the autobatch() function.\n\nNotes:\n If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.\n Otherwise, a default fraction of 0.6 is used.", "parameters": [ "model: torch.nn.Module", "imgsz: int", "amp: bool", "batch: Union[int, float]", "max_num_obj: int" ], "return_type": "int", "decorators": [], "complexity_score": 1, "dependencies": [ "os", "copy.deepcopy", "typing.Union", "numpy", "torch", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.profile_ops" ], "chunk_id": "function_check_train_batch_size_ae6db0ae" }, { "content": "def autobatch(\n model: torch.nn.Module,\n imgsz: int = 640,\n fraction: float = 0.60,\n batch_size: int = DEFAULT_CFG.batch,\n max_num_obj: int = 1,\n) -> int:\n \"\"\"\n Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.\n\n Args:\n model (torch.nn.Module): YOLO model to compute batch size for.\n imgsz (int, optional): The image size used as input for the YOLO model.\n fraction (float, optional): The fraction of available CUDA memory to use.\n batch_size (int, optional): The default batch size to use if an error is detected.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\n Returns:\n (int): The optimal batch size.\n \"\"\"\n # Check device\n prefix = colorstr(\"AutoBatch: \")\n LOGGER.info(f\"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.\")\n device = next(model.parameters()).device # get model device\n if device.type in {\"cpu\", \"mps\"}:\n LOGGER.warning(f\"{prefix}intended for CUDA devices, using default batch-size {batch_size}\")\n return batch_size\n if torch.backends.cudnn.benchmark:\n LOGGER.warning(f\"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}\")\n return batch_size\n\n # Inspect CUDA memory\n gb = 1 << 30 # bytes to GiB (1024 ** 3)\n d = f\"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}\" # 'CUDA:0'\n properties = torch.cuda.get_device_properties(device) # device properties\n t = properties.total_memory / gb # GiB total\n r = torch.cuda.memory_reserved(device) / gb # GiB reserved\n a = torch.cuda.memory_allocated(device) / gb # GiB allocated\n f = t - (r + a) # GiB free\n 
LOGGER.info(f\"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free\")\n\n # Profile batch sizes\n batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]\n try:\n img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]\n results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)\n\n # Fit a solution\n xy = [\n [x, y[2]]\n for i, (x, y) in enumerate(zip(batch_sizes, results))\n if y # valid result\n and isinstance(y[2], (int, float)) # is numeric\n and 0 < y[2] < t # between 0 and GPU limit\n and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory\n ]\n fit_x, fit_y = zip(*xy) if xy else ([], [])\n p = np.polyfit(fit_x, fit_y, deg=1) # first-degree polynomial fit in log space\n b = int((round(f * fraction) - p[1]) / p[0]) # y intercept (optimal batch size)\n if None in results: # some sizes failed\n i = results.index(None) # first fail index\n if b >= batch_sizes[i]: # y intercept above failure point\n b = batch_sizes[max(i - 1, 0)] # select prior safe point\n if b < 1 or b > 1024: # b outside of safe range\n LOGGER.warning(f\"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.\")\n b = batch_size\n\n fraction = (np.polyval(p, b) + r + a) / t # predicted fraction\n LOGGER.info(f\"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅\")\n return b\n except Exception as e:\n LOGGER.warning(f\"{prefix}error detected: {e}, using default batch-size {batch_size}.\")\n return batch_size\n finally:\n torch.cuda.empty_cache()", "chunk_type": "function", "name": "autobatch", "file_path": "ultralytics\\ultralytics\\utils\\autobatch.py", "start_line": 45, "end_line": 119, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.\n\nArgs:\n model (torch.nn.Module): YOLO model to compute batch size for.\n imgsz (int, optional): The image size used as input for the YOLO model.\n fraction (float, optional): The fraction of available CUDA memory to use.\n batch_size (int, optional): The default batch size to use if an error is detected.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\nReturns:\n (int): The optimal batch size.", "parameters": [ "model: torch.nn.Module", "imgsz: int", "fraction: float", "batch_size: int", "max_num_obj: int" ], "return_type": "int", "decorators": [], "complexity_score": 9, "dependencies": [ "os", "copy.deepcopy", "typing.Union", "numpy", "torch", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.profile_ops" ], "chunk_id": "function_autobatch_582df94b" }, { "content": "from typing import Any, Dict, List, Optional", "chunk_type": "import", "name": "Any, Dict, List, Optional", "file_path": "ultralytics\\ultralytics\\utils\\autodevice.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional_805eb627" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\utils\\autodevice.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 36, "parent_name": null, 
"docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_6ae07218" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\utils\\autodevice.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_24e0b7a5" }, { "content": "class GPUInfo:\n \"\"\"\n Manages NVIDIA GPU information via pynvml with robust error handling.\n\n Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle\n GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml\n library by logging warnings and disabling related features, preventing application crashes.\n\n Includes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU\n selection. Manages NVML initialization and shutdown internally.\n\n Attributes:\n pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.\n nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,\n False otherwise.\n gpu_stats (List[Dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on\n initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),\n 'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),\n 'power_limit' (W or 'N/A'). 
Empty if NVML is unavailable or queries fail.\n\n Methods:\n refresh_stats: Refresh the internal gpu_stats list by querying NVML.\n print_status: Print GPU status in a compact table format using current stats.\n select_idle_gpu: Select the most idle GPUs based on utilization and free memory.\n shutdown: Shut down NVML if it was initialized.\n\n Examples:\n Initialize GPUInfo and print status\n >>> gpu_info = GPUInfo()\n >>> gpu_info.print_status()\n\n Select idle GPUs with minimum memory requirements\n >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)\n >>> print(f\"Selected GPU indices: {selected}\")\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize GPUInfo, attempting to import and initialize pynvml.\"\"\"\n self.pynvml: Optional[Any] = None\n self.nvml_available: bool = False\n self.gpu_stats: List[Dict[str, Any]] = []\n\n try:\n check_requirements(\"pynvml>=12.0.0\")\n self.pynvml = __import__(\"pynvml\")\n self.pynvml.nvmlInit()\n self.nvml_available = True\n self.refresh_stats()\n except Exception as e:\n LOGGER.warning(f\"Failed to initialize pynvml, GPU stats disabled: {e}\")\n\n def __del__(self):\n \"\"\"Ensure NVML is shut down when the object is garbage collected.\"\"\"\n self.shutdown()\n\n def shutdown(self):\n \"\"\"Shut down NVML if it was initialized.\"\"\"\n if self.nvml_available and self.pynvml:\n try:\n self.pynvml.nvmlShutdown()\n except Exception:\n pass\n self.nvml_available = False\n\n def refresh_stats(self):\n \"\"\"Refresh the internal gpu_stats list by querying NVML.\"\"\"\n self.gpu_stats = []\n if not self.nvml_available or not self.pynvml:\n return\n\n try:\n device_count = self.pynvml.nvmlDeviceGetCount()\n for i in range(device_count):\n self.gpu_stats.append(self._get_device_stats(i))\n except Exception as e:\n LOGGER.warning(f\"Error during device query: {e}\")\n self.gpu_stats = []\n\n def _get_device_stats(self, index: int) -> Dict[str, Any]:\n \"\"\"Get stats for a single GPU device.\"\"\"\n handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)\n memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)\n util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)\n\n def safe_get(func, *args, default=-1, divisor=1):\n try:\n val = func(*args)\n return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val\n except Exception:\n return default\n\n temp_type = getattr(self.pynvml, \"NVML_TEMPERATURE_GPU\", -1)\n\n return {\n \"index\": index,\n \"name\": self.pynvml.nvmlDeviceGetName(handle),\n \"utilization\": util.gpu if util else -1,\n \"memory_used\": memory.used >> 20 if memory else -1, # Convert bytes to MiB\n \"memory_total\": memory.total >> 20 if memory else -1,\n \"memory_free\": memory.free >> 20 if memory else -1,\n \"temperature\": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),\n \"power_draw\": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000), # Convert mW to W\n \"power_limit\": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),\n }\n\n def print_status(self):\n \"\"\"Print GPU status in a compact table format using current stats.\"\"\"\n self.refresh_stats()\n if not self.gpu_stats:\n LOGGER.warning(\"No GPU stats available.\")\n return\n\n stats = self.gpu_stats\n name_len = max(len(gpu.get(\"name\", \"N/A\")) for gpu in stats)\n hdr = f\"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}\"\n LOGGER.info(f\"\\n--- GPU Status ---\\n{hdr}\\n{'-' * len(hdr)}\")\n\n for gpu in stats:\n u = 
f\"{gpu['utilization']:>5}%\" if gpu[\"utilization\"] >= 0 else \" N/A \"\n m = f\"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}\" if gpu[\"memory_used\"] >= 0 else \" N/A / N/A \"\n t = f\"{gpu['temperature']}C\" if gpu[\"temperature\"] >= 0 else \" N/A \"\n p = f\"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}\" if gpu[\"power_draw\"] >= 0 else \" N/A \"\n\n LOGGER.info(f\"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}\")\n\n LOGGER.info(f\"{'-' * len(hdr)}\\n\")\n\n def select_idle_gpu(\n self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0\n ) -> List[int]:\n \"\"\"\n Select the most idle GPUs based on utilization and free memory.\n\n Args:\n count (int): The number of idle GPUs to select.\n min_memory_fraction (float): Minimum free memory required as a fraction of total memory.\n min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.\n\n Returns:\n (List[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first).\n\n Notes:\n Returns fewer than 'count' if not enough qualify or exist.\n Returns basic CUDA indices if NVML fails. Empty list if no GPUs found.\n \"\"\"\n assert min_memory_fraction <= 1.0, f\"min_memory_fraction must be <= 1.0, got {min_memory_fraction}\"\n assert min_util_fraction <= 1.0, f\"min_util_fraction must be <= 1.0, got {min_util_fraction}\"\n LOGGER.info(\n f\"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%...\"\n )\n\n if count <= 0:\n return []\n\n self.refresh_stats()\n if not self.gpu_stats:\n LOGGER.warning(\"NVML stats unavailable.\")\n return []\n\n # Filter and sort eligible GPUs\n eligible_gpus = [\n gpu\n for gpu in self.gpu_stats\n if gpu.get(\"memory_free\", 0) / gpu.get(\"memory_total\", 1) >= min_memory_fraction\n and (100 - gpu.get(\"utilization\", 100)) >= min_util_fraction * 100\n ]\n eligible_gpus.sort(key=lambda x: (x.get(\"utilization\", 101), -x.get(\"memory_free\", 0)))\n\n # Select top 'count' indices\n selected = [gpu[\"index\"] for gpu in eligible_gpus[:count]]\n\n if selected:\n LOGGER.info(f\"Selected idle CUDA devices {selected}\")\n else:\n LOGGER.warning(\n f\"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%).\"\n )\n\n return selected", "chunk_type": "class", "name": "GPUInfo", "file_path": "ultralytics\\ultralytics\\utils\\autodevice.py", "start_line": 9, "end_line": 187, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Manages NVIDIA GPU information via pynvml with robust error handling.\n\nProvides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle\nGPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml\nlibrary by logging warnings and disabling related features, preventing application crashes.\n\nIncludes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU\nselection. Manages NVML initialization and shutdown internally.\n\nAttributes:\n pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.\n nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,\n False otherwise.\n gpu_stats (List[Dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. 
Populated on\n initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),\n 'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),\n 'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.\n\nMethods:\n refresh_stats: Refresh the internal gpu_stats list by querying NVML.\n print_status: Print GPU status in a compact table format using current stats.\n select_idle_gpu: Select the most idle GPUs based on utilization and free memory.\n shutdown: Shut down NVML if it was initialized.\n\nExamples:\n Initialize GPUInfo and print status\n >>> gpu_info = GPUInfo()\n >>> gpu_info.print_status()\n\n Select idle GPUs with minimum memory requirements\n >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)\n >>> print(f\"Selected GPU indices: {selected}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "ultralytics.utils.LOGGER", "ultralytics.utils.checks.check_requirements" ], "chunk_id": "class_GPUInfo_e0b9c1fd" }, { "content": "import glob", "chunk_type": "import", "name": "glob", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 30, "end_line": 30, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_glob_daf13c20" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 31, "end_line": 31, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_f2fede17" }, { "content": "import platform", "chunk_type": "import", "name": "platform", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 32, "end_line": 32, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_platform_a51b6978" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 33, "end_line": 33, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_7f62a565" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 34, "end_line": 34, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_0ce18268" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 35, "end_line": 35, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_7ca1fa09" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": 
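A usage sketch for the `GPUInfo` class above, including the `torch.cuda` fallback its docstring describes when NVML is unavailable. The import path follows the `file_path` shown for this chunk and assumes the `ultralytics` package (and optionally `pynvml`) is installed.

```python
# Hedged usage sketch: pick an idle GPU with GPUInfo, falling back to plain CUDA indices.
import torch
from ultralytics.utils.autodevice import GPUInfo

gpu_info = GPUInfo()
gpu_info.print_status()  # logs a warning and returns if NVML stats are unavailable

selected = gpu_info.select_idle_gpu(count=1, min_memory_fraction=0.2, min_util_fraction=0.5)
if not selected and torch.cuda.is_available():
    # Fallback described in the class docstring: basic CUDA indices when NVML fails
    selected = list(range(torch.cuda.device_count()))[:1]

device = f"cuda:{selected[0]}" if selected else "cpu"
print(f"Using device: {device}")
gpu_info.shutdown()
```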
"ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_404f7d8b" }, { "content": "from typing import List, Optional, Tuple, Union", "chunk_type": "import", "name": "List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple, Union_379ccf47" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_0143459b" }, { "content": "import torch.cuda", "chunk_type": "import", "name": "torch.cuda", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 40, "end_line": 40, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.cuda_8e6ff85d" }, { "content": "from ultralytics import YOLO, YOLOWorld", "chunk_type": "import", "name": "YOLO, YOLOWorld", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 42, "end_line": 42, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO, YOLOWorld_48830e91" }, { "content": "from ultralytics.cfg import TASK2DATA, TASK2METRIC", "chunk_type": "import", "name": "TASK2DATA, TASK2METRIC", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 43, "end_line": 43, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, TASK2METRIC_5a2739ff" }, { "content": "from ultralytics.engine.exporter import export_formats", "chunk_type": "import", "name": "export_formats", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 44, "end_line": 44, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_export_formats_39ba1379" }, { "content": "from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML", "chunk_type": "import", "name": "ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 45, "end_line": 45, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML_36d5d8bd" }, { "content": "from ultralytics.utils.checks 
import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip", "chunk_type": "import", "name": "IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 46, "end_line": 46, "start_col": 0, "end_col": 109, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip_2ed1aae2" }, { "content": "from ultralytics.utils.downloads import safe_download", "chunk_type": "import", "name": "safe_download", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 47, "end_line": 47, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_safe_download_15b261a3" }, { "content": "from ultralytics.utils.files import file_size", "chunk_type": "import", "name": "file_size", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 48, "end_line": 48, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_file_size_aae2992d" }, { "content": "from ultralytics.utils.torch_utils import get_cpu_info, select_device", "chunk_type": "import", "name": "get_cpu_info, select_device", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 49, "end_line": 49, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_cpu_info, select_device_e4e5353e" }, { "content": "def benchmark(\n model=WEIGHTS_DIR / \"yolo11n.pt\",\n data=None,\n imgsz=160,\n half=False,\n int8=False,\n device=\"cpu\",\n verbose=False,\n eps=1e-3,\n format=\"\",\n **kwargs,\n):\n \"\"\"\n Benchmark a YOLO model across different formats for speed and accuracy.\n\n Args:\n model (str | Path): Path to the model file or directory.\n data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.\n imgsz (int): Image size for the benchmark.\n half (bool): Use half-precision for the model if True.\n int8 (bool): Use int8-precision for the model if True.\n device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.\n verbose (bool | float): If True or a float, assert benchmarks pass with given metric.\n eps (float): Epsilon value for divide by zero prevention.\n format (str): Export format for benchmarking. 
If not supplied all formats are benchmarked.\n **kwargs (Any): Additional keyword arguments for exporter.\n\n Returns:\n (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,\n and inference time.\n\n Examples:\n Benchmark a YOLO model with default settings:\n >>> from ultralytics.utils.benchmarks import benchmark\n >>> benchmark(model=\"yolo11n.pt\", imgsz=640)\n \"\"\"\n imgsz = check_imgsz(imgsz)\n assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, \"benchmark() only supports square imgsz.\"\n\n import pandas as pd # scope for faster 'import ultralytics'\n\n pd.options.display.max_columns = 10\n pd.options.display.width = 120\n device = select_device(device, verbose=False)\n if isinstance(model, (str, Path)):\n model = YOLO(model)\n is_end2end = getattr(model.model.model[-1], \"end2end\", False)\n data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect\n key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect\n\n y = []\n t0 = time.time()\n\n format_arg = format.lower()\n if format_arg:\n formats = frozenset(export_formats()[\"Argument\"])\n assert format in formats, f\"Expected format to be one of {formats}, but got '{format_arg}'.\"\n for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):\n emoji, filename = \"❌\", None # export defaults\n try:\n if format_arg and format_arg != format:\n continue\n\n # Checks\n if format == \"pb\":\n assert model.task != \"obb\", \"TensorFlow GraphDef not supported for OBB task\"\n elif format == \"edgetpu\":\n assert LINUX and not ARM64, \"Edge TPU export only supported on non-aarch64 Linux\"\n elif format in {\"coreml\", \"tfjs\"}:\n assert MACOS or (LINUX and not ARM64), (\n \"CoreML and TF.js export only supported on macOS and non-aarch64 Linux\"\n )\n if format == \"coreml\":\n assert not IS_PYTHON_3_13, \"CoreML not supported on Python 3.13\"\n if format in {\"saved_model\", \"pb\", \"tflite\", \"edgetpu\", \"tfjs\"}:\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet\"\n # assert not IS_PYTHON_MINIMUM_3_12, \"TFLite exports not supported on Python>=3.12 yet\"\n if format == \"paddle\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 Paddle exports not supported yet\"\n assert model.task != \"obb\", \"Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024\"\n assert not is_end2end, \"End-to-end models not supported by PaddlePaddle yet\"\n assert (LINUX and not IS_JETSON) or MACOS, \"Windows and Jetson Paddle exports not supported yet\"\n if format == \"mnn\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 MNN exports not supported yet\"\n if format == \"ncnn\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 NCNN exports not supported yet\"\n if format == \"imx\":\n assert not is_end2end\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 IMX exports not supported\"\n assert model.task == \"detect\", \"IMX only supported for detection task\"\n assert \"C2f\" in model.__str__(), \"IMX only supported for YOLOv8\" # TODO: enable for YOLO11\n if format == \"rknn\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 RKNN exports not supported yet\"\n assert not is_end2end, \"End-to-end models not supported by RKNN yet\"\n assert LINUX, \"RKNN only supported on Linux\"\n assert not is_rockchip(), \"RKNN Inference only supported on Rockchip devices\"\n if \"cpu\" in device.type:\n assert cpu, 
\"inference not supported on CPU\"\n if \"cuda\" in device.type:\n assert gpu, \"inference not supported on GPU\"\n\n # Export\n if format == \"-\":\n filename = model.pt_path or model.ckpt_path or model.model_name\n exported_model = model # PyTorch format\n else:\n filename = model.export(\n imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs\n )\n exported_model = YOLO(filename, task=model.task)\n assert suffix in str(filename), \"export failed\"\n emoji = \"❎\" # indicates export succeeded\n\n # Predict\n assert model.task != \"pose\" or format != \"pb\", \"GraphDef Pose inference is not supported\"\n assert format not in {\"edgetpu\", \"tfjs\"}, \"inference not supported\"\n assert format != \"coreml\" or platform.system() == \"Darwin\", \"inference only supported on macOS>=10.13\"\n if format == \"ncnn\":\n assert not is_end2end, \"End-to-end torch.topk operation is not supported for NCNN prediction yet\"\n exported_model.predict(ASSETS / \"bus.jpg\", imgsz=imgsz, device=device, half=half, verbose=False)\n\n # Validate\n results = exported_model.val(\n data=data,\n batch=1,\n imgsz=imgsz,\n plots=False,\n device=device,\n half=half,\n int8=int8,\n verbose=False,\n conf=0.001, # all the pre-set benchmark mAP values are based on conf=0.001\n )\n metric, speed = results.results_dict[key], results.speed[\"inference\"]\n fps = round(1000 / (speed + eps), 2) # frames per second\n y.append([name, \"✅\", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])\n except Exception as e:\n if verbose:\n assert type(e) is AssertionError, f\"Benchmark failure for {name}: {e}\"\n LOGGER.error(f\"Benchmark failure for {name}: {e}\")\n y.append([name, emoji, round(file_size(filename), 1), None, None, None]) # mAP, t_inference\n\n # Print results\n check_yolo(device=device) # print system info\n df = pd.DataFrame(y, columns=[\"Format\", \"Status❔\", \"Size (MB)\", key, \"Inference time (ms/im)\", \"FPS\"])\n\n name = model.model_name\n dt = time.time() - t0\n legend = \"Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed\"\n s = f\"\\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\\n{legend}\\n{df.fillna('-')}\\n\"\n LOGGER.info(s)\n with open(\"benchmarks.log\", \"a\", errors=\"ignore\", encoding=\"utf-8\") as f:\n f.write(s)\n\n if verbose and isinstance(verbose, float):\n metrics = df[key].array # values to compare to floor\n floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n\n assert all(x > floor for x in metrics if pd.notna(x)), f\"Benchmark failure: metric(s) < floor {floor}\"\n\n return df", "chunk_type": "function", "name": "benchmark", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 52, "end_line": 211, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": "Benchmark a YOLO model across different formats for speed and accuracy.\n\nArgs:\n model (str | Path): Path to the model file or directory.\n data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.\n imgsz (int): Image size for the benchmark.\n half (bool): Use half-precision for the model if True.\n int8 (bool): Use int8-precision for the model if True.\n device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.\n verbose (bool | float): If True or a float, assert benchmarks pass with given metric.\n eps (float): Epsilon value for divide by zero prevention.\n format (str): Export format for benchmarking. 
If not supplied all formats are benchmarked.\n **kwargs (Any): Additional keyword arguments for exporter.\n\nReturns:\n (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,\n and inference time.\n\nExamples:\n Benchmark a YOLO model with default settings:\n >>> from ultralytics.utils.benchmarks import benchmark\n >>> benchmark(model=\"yolo11n.pt\", imgsz=640)", "parameters": [ "model", "data", "imgsz", "half", "int8", "device", "verbose", "eps", "format" ], "return_type": null, "decorators": [], "complexity_score": 23, "dependencies": [ "glob", "os", "platform", "re", "shutil", "time", "pathlib.Path", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch.cuda", "ultralytics.YOLO", "ultralytics.YOLOWorld", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2METRIC", "ultralytics.engine.exporter.export_formats", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.TQDM", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.YAML", "ultralytics.utils.checks.IS_PYTHON_3_13", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_yolo", "ultralytics.utils.checks.is_rockchip", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.files.file_size", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.select_device", "pandas", "roboflow.Roboflow", "onnxruntime" ], "chunk_id": "function_benchmark_488ba16b" }, { "content": "class RF100Benchmark:\n \"\"\"\n Benchmark YOLO model performance across various formats for speed and accuracy.\n\n This class provides functionality to benchmark YOLO models on the RF100 dataset collection.\n\n Attributes:\n ds_names (List[str]): Names of datasets used for benchmarking.\n ds_cfg_list (List[Path]): List of paths to dataset configuration files.\n rf (Roboflow): Roboflow instance for accessing datasets.\n val_metrics (List[str]): Metrics used for validation.\n\n Methods:\n set_key: Set Roboflow API key for accessing datasets.\n parse_dataset: Parse dataset links and download datasets.\n fix_yaml: Fix train and validation paths in YAML files.\n evaluate: Evaluate model performance on validation results.\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats.\"\"\"\n self.ds_names = []\n self.ds_cfg_list = []\n self.rf = None\n self.val_metrics = [\"class\", \"images\", \"targets\", \"precision\", \"recall\", \"map50\", \"map95\"]\n\n def set_key(self, api_key: str):\n \"\"\"\n Set Roboflow API key for processing.\n\n Args:\n api_key (str): The API key.\n\n Examples:\n Set the Roboflow API key for accessing datasets:\n >>> benchmark = RF100Benchmark()\n >>> benchmark.set_key(\"your_roboflow_api_key\")\n \"\"\"\n check_requirements(\"roboflow\")\n from roboflow import Roboflow\n\n self.rf = Roboflow(api_key=api_key)\n\n def parse_dataset(self, ds_link_txt: str = \"datasets_links.txt\"):\n \"\"\"\n Parse dataset links and download datasets.\n\n Args:\n ds_link_txt (str): Path to the file containing dataset links.\n\n Returns:\n ds_names (List[str]): List of dataset names.\n ds_cfg_list (List[Path]): List of paths to dataset configuration files.\n\n Examples:\n >>> benchmark = RF100Benchmark()\n >>> benchmark.set_key(\"api_key\")\n >>> 
benchmark.parse_dataset(\"datasets_links.txt\")\n \"\"\"\n (shutil.rmtree(\"rf-100\"), os.mkdir(\"rf-100\")) if os.path.exists(\"rf-100\") else os.mkdir(\"rf-100\")\n os.chdir(\"rf-100\")\n os.mkdir(\"ultralytics-benchmarks\")\n safe_download(\"https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt\")\n\n with open(ds_link_txt, encoding=\"utf-8\") as file:\n for line in file:\n try:\n _, url, workspace, project, version = re.split(\"/+\", line.strip())\n self.ds_names.append(project)\n proj_version = f\"{project}-{version}\"\n if not Path(proj_version).exists():\n self.rf.workspace(workspace).project(project).version(version).download(\"yolov8\")\n else:\n LOGGER.info(\"Dataset already downloaded.\")\n self.ds_cfg_list.append(Path.cwd() / proj_version / \"data.yaml\")\n except Exception:\n continue\n\n return self.ds_names, self.ds_cfg_list\n\n @staticmethod\n def fix_yaml(path: Path):\n \"\"\"Fix the train and validation paths in a given YAML file.\"\"\"\n yaml_data = YAML.load(path)\n yaml_data[\"train\"] = \"train/images\"\n yaml_data[\"val\"] = \"valid/images\"\n YAML.dump(yaml_data, path)\n\n def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int):\n \"\"\"\n Evaluate model performance on validation results.\n\n Args:\n yaml_path (str): Path to the YAML configuration file.\n val_log_file (str): Path to the validation log file.\n eval_log_file (str): Path to the evaluation log file.\n list_ind (int): Index of the current dataset in the list.\n\n Returns:\n (float): The mean average precision (mAP) value for the evaluated model.\n\n Examples:\n Evaluate a model on a specific dataset\n >>> benchmark = RF100Benchmark()\n >>> benchmark.evaluate(\"path/to/data.yaml\", \"path/to/val_log.txt\", \"path/to/eval_log.txt\", 0)\n \"\"\"\n skip_symbols = [\"🚀\", \"⚠️\", \"💡\", \"❌\"]\n class_names = YAML.load(yaml_path)[\"names\"]\n with open(val_log_file, encoding=\"utf-8\") as f:\n lines = f.readlines()\n eval_lines = []\n for line in lines:\n if any(symbol in line for symbol in skip_symbols):\n continue\n entries = line.split(\" \")\n entries = list(filter(lambda val: val != \"\", entries))\n entries = [e.strip(\"\\n\") for e in entries]\n eval_lines.extend(\n {\n \"class\": entries[0],\n \"images\": entries[1],\n \"targets\": entries[2],\n \"precision\": entries[3],\n \"recall\": entries[4],\n \"map50\": entries[5],\n \"map95\": entries[6],\n }\n for e in entries\n if e in class_names or (e == \"all\" and \"(AP)\" not in entries and \"(AR)\" not in entries)\n )\n map_val = 0.0\n if len(eval_lines) > 1:\n LOGGER.info(\"Multiple dicts found\")\n for lst in eval_lines:\n if lst[\"class\"] == \"all\":\n map_val = lst[\"map50\"]\n else:\n LOGGER.info(\"Single dict found\")\n map_val = [res[\"map50\"] for res in eval_lines][0]\n\n with open(eval_log_file, \"a\", encoding=\"utf-8\") as f:\n f.write(f\"{self.ds_names[list_ind]}: {map_val}\\n\")\n\n return float(map_val)", "chunk_type": "class", "name": "RF100Benchmark", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 214, "end_line": 357, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": "Benchmark YOLO model performance across various formats for speed and accuracy.\n\nThis class provides functionality to benchmark YOLO models on the RF100 dataset collection.\n\nAttributes:\n ds_names (List[str]): Names of datasets used for benchmarking.\n ds_cfg_list (List[Path]): List of paths to dataset configuration files.\n rf (Roboflow): Roboflow instance for 
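The RF100 workflow above chains `set_key`, `parse_dataset`, and `fix_yaml` before any training or validation runs. A hedged end-to-end sketch; `"YOUR_API_KEY"` is a placeholder and the download step requires Roboflow access, so this is illustrative rather than turnkey.

```python
# Hedged sketch of the RF100Benchmark preparation workflow.
from ultralytics.utils.benchmarks import RF100Benchmark

rf100 = RF100Benchmark()
rf100.set_key("YOUR_API_KEY")                               # placeholder Roboflow API key
ds_names, ds_cfgs = rf100.parse_dataset("datasets_links.txt")
for cfg in ds_cfgs:
    RF100Benchmark.fix_yaml(cfg)                            # point train/val at the expected folder layout
print(f"Prepared {len(ds_names)} datasets for benchmarking")
```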
accessing datasets.\n val_metrics (List[str]): Metrics used for validation.\n\nMethods:\n set_key: Set Roboflow API key for accessing datasets.\n parse_dataset: Parse dataset links and download datasets.\n fix_yaml: Fix train and validation paths in YAML files.\n evaluate: Evaluate model performance on validation results.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "os", "platform", "re", "shutil", "time", "pathlib.Path", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch.cuda", "ultralytics.YOLO", "ultralytics.YOLOWorld", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2METRIC", "ultralytics.engine.exporter.export_formats", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.TQDM", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.YAML", "ultralytics.utils.checks.IS_PYTHON_3_13", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_yolo", "ultralytics.utils.checks.is_rockchip", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.files.file_size", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.select_device", "pandas", "roboflow.Roboflow", "onnxruntime" ], "chunk_id": "class_RF100Benchmark_a473c734" }, { "content": "class ProfileModels:\n \"\"\"\n ProfileModels class for profiling different models on ONNX and TensorRT.\n\n This class profiles the performance of different models, returning results such as model speed and FLOPs.\n\n Attributes:\n paths (List[str]): Paths of the models to profile.\n num_timed_runs (int): Number of timed runs for the profiling.\n num_warmup_runs (int): Number of warmup runs before profiling.\n min_time (float): Minimum number of seconds to profile for.\n imgsz (int): Image size used in the models.\n half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.\n trt (bool): Flag to indicate whether to profile using TensorRT.\n device (torch.device): Device used for profiling.\n\n Methods:\n run: Profile YOLO models for speed and accuracy across various formats.\n get_files: Get all relevant model files.\n get_onnx_model_info: Extract metadata from an ONNX model.\n iterative_sigma_clipping: Apply sigma clipping to remove outliers.\n profile_tensorrt_model: Profile a TensorRT model.\n profile_onnx_model: Profile an ONNX model.\n generate_table_row: Generate a table row with model metrics.\n generate_results_dict: Generate a dictionary of profiling results.\n print_table: Print a formatted table of results.\n\n Examples:\n Profile models and print results\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"], imgsz=640)\n >>> profiler.run()\n \"\"\"\n\n def __init__(\n self,\n paths: List[str],\n num_timed_runs: int = 100,\n num_warmup_runs: int = 10,\n min_time: float = 60,\n imgsz: int = 640,\n half: bool = True,\n trt: bool = True,\n device: Optional[Union[torch.device, str]] = None,\n ):\n \"\"\"\n Initialize the ProfileModels class for profiling models.\n\n Args:\n paths (List[str]): List of paths of the models to be profiled.\n num_timed_runs (int): Number of timed runs for the profiling.\n num_warmup_runs (int): Number of warmup runs before the actual profiling starts.\n min_time (float): Minimum time in seconds for 
profiling a model.\n imgsz (int): Size of the image used during profiling.\n half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.\n trt (bool): Flag to indicate whether to profile using TensorRT.\n device (torch.device | str | None): Device used for profiling. If None, it is determined automatically.\n\n Notes:\n FP16 'half' argument option removed for ONNX as slower on CPU than FP32.\n\n Examples:\n Initialize and profile models\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"], imgsz=640)\n >>> profiler.run()\n \"\"\"\n self.paths = paths\n self.num_timed_runs = num_timed_runs\n self.num_warmup_runs = num_warmup_runs\n self.min_time = min_time\n self.imgsz = imgsz\n self.half = half\n self.trt = trt # run TensorRT profiling\n self.device = device if isinstance(device, torch.device) else select_device(device)\n\n def run(self):\n \"\"\"\n Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.\n\n Returns:\n (List[dict]): List of dictionaries containing profiling results for each model.\n\n Examples:\n Profile models and print results\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"])\n >>> results = profiler.run()\n \"\"\"\n files = self.get_files()\n\n if not files:\n LOGGER.warning(\"No matching *.pt or *.onnx files found.\")\n return []\n\n table_rows = []\n output = []\n for file in files:\n engine_file = file.with_suffix(\".engine\")\n if file.suffix in {\".pt\", \".yaml\", \".yml\"}:\n model = YOLO(str(file))\n model.fuse() # to report correct params and GFLOPs in model.info()\n model_info = model.info()\n if self.trt and self.device.type != \"cpu\" and not engine_file.is_file():\n engine_file = model.export(\n format=\"engine\",\n half=self.half,\n imgsz=self.imgsz,\n device=self.device,\n verbose=False,\n )\n onnx_file = model.export(\n format=\"onnx\",\n imgsz=self.imgsz,\n device=self.device,\n verbose=False,\n )\n elif file.suffix == \".onnx\":\n model_info = self.get_onnx_model_info(file)\n onnx_file = file\n else:\n continue\n\n t_engine = self.profile_tensorrt_model(str(engine_file))\n t_onnx = self.profile_onnx_model(str(onnx_file))\n table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))\n output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))\n\n self.print_table(table_rows)\n return output\n\n def get_files(self):\n \"\"\"\n Return a list of paths for all relevant model files given by the user.\n\n Returns:\n (List[Path]): List of Path objects for the model files.\n \"\"\"\n files = []\n for path in self.paths:\n path = Path(path)\n if path.is_dir():\n extensions = [\"*.pt\", \"*.onnx\", \"*.yaml\"]\n files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])\n elif path.suffix in {\".pt\", \".yaml\", \".yml\"}: # add non-existing\n files.append(str(path))\n else:\n files.extend(glob.glob(str(path)))\n\n LOGGER.info(f\"Profiling: {sorted(files)}\")\n return [Path(file) for file in sorted(files)]\n\n @staticmethod\n def get_onnx_model_info(onnx_file: str):\n \"\"\"Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape.\"\"\"\n return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)\n\n @staticmethod\n def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3):\n \"\"\"\n Apply 
iterative sigma clipping to data to remove outliers.\n\n Args:\n data (np.ndarray): Input data array.\n sigma (float): Number of standard deviations to use for clipping.\n max_iters (int): Maximum number of iterations for the clipping process.\n\n Returns:\n (np.ndarray): Clipped data array with outliers removed.\n \"\"\"\n data = np.array(data)\n for _ in range(max_iters):\n mean, std = np.mean(data), np.std(data)\n clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]\n if len(clipped_data) == len(data):\n break\n data = clipped_data\n return data\n\n def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):\n \"\"\"\n Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.\n\n Args:\n engine_file (str): Path to the TensorRT engine file.\n eps (float): Small epsilon value to prevent division by zero.\n\n Returns:\n mean_time (float): Mean inference time in milliseconds.\n std_time (float): Standard deviation of inference time in milliseconds.\n \"\"\"\n if not self.trt or not Path(engine_file).is_file():\n return 0.0, 0.0\n\n # Model and input\n model = YOLO(engine_file)\n input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8) # use uint8 for Classify\n\n # Warmup runs\n elapsed = 0.0\n for _ in range(3):\n start_time = time.time()\n for _ in range(self.num_warmup_runs):\n model(input_data, imgsz=self.imgsz, verbose=False)\n elapsed = time.time() - start_time\n\n # Compute number of runs as higher of min_time or num_timed_runs\n num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)\n\n # Timed runs\n run_times = []\n for _ in TQDM(range(num_runs), desc=engine_file):\n results = model(input_data, imgsz=self.imgsz, verbose=False)\n run_times.append(results[0].speed[\"inference\"]) # Convert to milliseconds\n\n run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping\n return np.mean(run_times), np.std(run_times)\n\n def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):\n \"\"\"\n Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.\n\n Args:\n onnx_file (str): Path to the ONNX model file.\n eps (float): Small epsilon value to prevent division by zero.\n\n Returns:\n mean_time (float): Mean inference time in milliseconds.\n std_time (float): Standard deviation of inference time in milliseconds.\n \"\"\"\n check_requirements(\"onnxruntime\")\n import onnxruntime as ort\n\n # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'\n sess_options = ort.SessionOptions()\n sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n sess_options.intra_op_num_threads = 8 # Limit the number of threads\n sess = ort.InferenceSession(onnx_file, sess_options, providers=[\"CPUExecutionProvider\"])\n\n input_tensor = sess.get_inputs()[0]\n input_type = input_tensor.type\n dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape\n input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape\n\n # Mapping ONNX datatype to numpy datatype\n if \"float16\" in input_type:\n input_dtype = np.float16\n elif \"float\" in input_type:\n input_dtype = np.float32\n elif \"double\" in input_type:\n input_dtype = np.float64\n elif \"int64\" in input_type:\n input_dtype = np.int64\n elif \"int32\" in input_type:\n input_dtype = np.int32\n else:\n 
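The sigma-clipping step is what keeps a few slow, outlier inference runs from skewing the reported mean latency. A small numeric demo calling the static method directly; the latencies are synthetic, with two injected outliers that the clipping discards.

```python
# Demo of the iterative sigma clipping used to stabilize timing measurements.
import numpy as np
from ultralytics.utils.benchmarks import ProfileModels

times_ms = np.array([9.8, 10.1, 10.0, 9.9, 10.2, 45.0, 10.1, 9.7, 38.5])  # synthetic latencies + outliers
clean = ProfileModels.iterative_sigma_clipping(times_ms, sigma=2, max_iters=3)
print(f"kept {len(clean)}/{len(times_ms)} samples, mean {clean.mean():.2f} ms vs raw {times_ms.mean():.2f} ms")
```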
raise ValueError(f\"Unsupported ONNX datatype {input_type}\")\n\n input_data = np.random.rand(*input_shape).astype(input_dtype)\n input_name = input_tensor.name\n output_name = sess.get_outputs()[0].name\n\n # Warmup runs\n elapsed = 0.0\n for _ in range(3):\n start_time = time.time()\n for _ in range(self.num_warmup_runs):\n sess.run([output_name], {input_name: input_data})\n elapsed = time.time() - start_time\n\n # Compute number of runs as higher of min_time or num_timed_runs\n num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)\n\n # Timed runs\n run_times = []\n for _ in TQDM(range(num_runs), desc=onnx_file):\n start_time = time.time()\n sess.run([output_name], {input_name: input_data})\n run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds\n\n run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping\n return np.mean(run_times), np.std(run_times)\n\n def generate_table_row(\n self,\n model_name: str,\n t_onnx: Tuple[float, float],\n t_engine: Tuple[float, float],\n model_info: Tuple[float, float, float, float],\n ):\n \"\"\"\n Generate a table row string with model performance metrics.\n\n Args:\n model_name (str): Name of the model.\n t_onnx (tuple): ONNX model inference time statistics (mean, std).\n t_engine (tuple): TensorRT engine inference time statistics (mean, std).\n model_info (tuple): Model information (layers, params, gradients, flops).\n\n Returns:\n (str): Formatted table row string with model metrics.\n \"\"\"\n layers, params, gradients, flops = model_info\n return (\n f\"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±\"\n f\"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |\"\n )\n\n @staticmethod\n def generate_results_dict(\n model_name: str,\n t_onnx: Tuple[float, float],\n t_engine: Tuple[float, float],\n model_info: Tuple[float, float, float, float],\n ):\n \"\"\"\n Generate a dictionary of profiling results.\n\n Args:\n model_name (str): Name of the model.\n t_onnx (tuple): ONNX model inference time statistics (mean, std).\n t_engine (tuple): TensorRT engine inference time statistics (mean, std).\n model_info (tuple): Model information (layers, params, gradients, flops).\n\n Returns:\n (dict): Dictionary containing profiling results.\n \"\"\"\n layers, params, gradients, flops = model_info\n return {\n \"model/name\": model_name,\n \"model/parameters\": params,\n \"model/GFLOPs\": round(flops, 3),\n \"model/speed_ONNX(ms)\": round(t_onnx[0], 3),\n \"model/speed_TensorRT(ms)\": round(t_engine[0], 3),\n }\n\n @staticmethod\n def print_table(table_rows: List[str]):\n \"\"\"\n Print a formatted table of model profiling results.\n\n Args:\n table_rows (List[str]): List of formatted table row strings.\n \"\"\"\n gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"GPU\"\n headers = [\n \"Model\",\n \"size
(pixels)\",\n \"mAPval
50-95\",\n f\"Speed
CPU ({get_cpu_info()}) ONNX
(ms)\",\n f\"Speed
{gpu} TensorRT
(ms)\",\n \"params
(M)\",\n \"FLOPs
(B)\",\n ]\n header = \"|\" + \"|\".join(f\" {h} \" for h in headers) + \"|\"\n separator = \"|\" + \"|\".join(\"-\" * (len(h) + 2) for h in headers) + \"|\"\n\n LOGGER.info(f\"\\n\\n{header}\")\n LOGGER.info(separator)\n for row in table_rows:\n LOGGER.info(row)", "chunk_type": "class", "name": "ProfileModels", "file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py", "start_line": 360, "end_line": 720, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "ProfileModels class for profiling different models on ONNX and TensorRT.\n\nThis class profiles the performance of different models, returning results such as model speed and FLOPs.\n\nAttributes:\n paths (List[str]): Paths of the models to profile.\n num_timed_runs (int): Number of timed runs for the profiling.\n num_warmup_runs (int): Number of warmup runs before profiling.\n min_time (float): Minimum number of seconds to profile for.\n imgsz (int): Image size used in the models.\n half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.\n trt (bool): Flag to indicate whether to profile using TensorRT.\n device (torch.device): Device used for profiling.\n\nMethods:\n run: Profile YOLO models for speed and accuracy across various formats.\n get_files: Get all relevant model files.\n get_onnx_model_info: Extract metadata from an ONNX model.\n iterative_sigma_clipping: Apply sigma clipping to remove outliers.\n profile_tensorrt_model: Profile a TensorRT model.\n profile_onnx_model: Profile an ONNX model.\n generate_table_row: Generate a table row with model metrics.\n generate_results_dict: Generate a dictionary of profiling results.\n print_table: Print a formatted table of results.\n\nExamples:\n Profile models and print results\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"], imgsz=640)\n >>> profiler.run()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "glob", "os", "platform", "re", "shutil", "time", "pathlib.Path", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch.cuda", "ultralytics.YOLO", "ultralytics.YOLOWorld", "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2METRIC", "ultralytics.engine.exporter.export_formats", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.TQDM", "ultralytics.utils.WEIGHTS_DIR", "ultralytics.utils.YAML", "ultralytics.utils.checks.IS_PYTHON_3_13", "ultralytics.utils.checks.check_imgsz", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.checks.check_yolo", "ultralytics.utils.checks.is_rockchip", "ultralytics.utils.downloads.safe_download", "ultralytics.utils.files.file_size", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.select_device", "pandas", "roboflow.Roboflow", "onnxruntime" ], "chunk_id": "class_ProfileModels_c4eb9a5b" }, { "content": "import functools", "chunk_type": "import", "name": "functools", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_functools_6378b944" }, { "content": "import glob", "chunk_type": "import", "name": "glob", "file_path": 
"ultralytics\\ultralytics\\utils\\checks.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_glob_83cd0773" }, { "content": "import inspect", "chunk_type": "import", "name": "inspect", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_inspect_f54848db" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_0f8e4529" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_414f6710" }, { "content": "import platform", "chunk_type": "import", "name": "platform", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_platform_a1f5096d" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_b84d5a11" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_d7f542ca" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_21ebf160" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_a6472836" }, { "content": "from importlib import metadata", "chunk_type": "import", "name": "metadata", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 13, "end_line": 13, "start_col": 0, 
"end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_metadata_f5ca7ec6" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_0f6f55a6" }, { "content": "from types import SimpleNamespace", "chunk_type": "import", "name": "SimpleNamespace", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SimpleNamespace_1f99f4c4" }, { "content": "from typing import Optional", "chunk_type": "import", "name": "Optional", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Optional_024c48a9" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_a2b45917" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_47f34676" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_133e9c35" }, { "content": "from ultralytics.utils import (\n ARM64,\n ASSETS,\n AUTOINSTALL,\n IS_COLAB,\n IS_GIT_DIR,\n IS_JETSON,\n IS_KAGGLE,\n IS_PIP_PACKAGE,\n LINUX,\n LOGGER,\n MACOS,\n ONLINE,\n PYTHON_VERSION,\n RKNN_CHIPS,\n ROOT,\n TORCHVISION_VERSION,\n USER_CONFIG_DIR,\n WINDOWS,\n Retry,\n ThreadingLocked,\n TryExcept,\n clean_url,\n colorstr,\n downloads,\n is_github_action_running,\n url2file,\n)", "chunk_type": "import", "name": "ARM64, ASSETS, AUTOINSTALL, IS_COLAB, IS_GIT_DIR, IS_JETSON, IS_KAGGLE, IS_PIP_PACKAGE, LINUX, LOGGER, MACOS, ONLINE, PYTHON_VERSION, RKNN_CHIPS, ROOT, TORCHVISION_VERSION, USER_CONFIG_DIR, WINDOWS, Retry, ThreadingLocked, TryExcept, clean_url, colorstr, downloads, is_github_action_running, url2file", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 22, "end_line": 49, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": 
null, "dependencies": null, "chunk_id": "import_ARM64, ASSETS, AUTOINSTALL, IS_COLAB, IS_GIT_DIR, IS_JETSON, IS_KAGGLE, IS_PIP_PACKAGE, LINUX, LOGGER, MACOS, ONLINE, PYTHON_VERSION, RKNN_CHIPS, ROOT, TORCHVISION_VERSION, USER_CONFIG_DIR, WINDOWS, Retry, ThreadingLocked, TryExcept, clean_url, colorstr, downloads, is_github_action_running, url2file_8c9fe9a2" }, { "content": "def parse_requirements(file_path=ROOT.parent / \"requirements.txt\", package=\"\"):\n \"\"\"\n Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.\n\n Args:\n file_path (Path): Path to the requirements.txt file.\n package (str, optional): Python package to use instead of requirements.txt file.\n\n Returns:\n requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and\n `specifier` attributes.\n\n Examples:\n >>> from ultralytics.utils.checks import parse_requirements\n >>> parse_requirements(package=\"ultralytics\")\n \"\"\"\n if package:\n requires = [x for x in metadata.distribution(package).requires if \"extra == \" not in x]\n else:\n requires = Path(file_path).read_text().splitlines()\n\n requirements = []\n for line in requires:\n line = line.strip()\n if line and not line.startswith(\"#\"):\n line = line.partition(\"#\")[0].strip() # ignore inline comments\n if match := re.match(r\"([a-zA-Z0-9-_]+)\\s*([<>!=~]+.*)?\", line):\n requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else \"\"))\n\n return requirements", "chunk_type": "function", "name": "parse_requirements", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 52, "end_line": 81, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.\n\nArgs:\n file_path (Path): Path to the requirements.txt file.\n package (str, optional): Python package to use instead of requirements.txt file.\n\nReturns:\n requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and\n `specifier` attributes.\n\nExamples:\n >>> from ultralytics.utils.checks import parse_requirements\n >>> parse_requirements(package=\"ultralytics\")", "parameters": [ "file_path", "package" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", 
"ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_parse_requirements_dc571047" }, { "content": "def parse_version(version=\"0.0.0\") -> tuple:\n \"\"\"\n Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.\n\n Args:\n version (str): Version string, i.e. '2.0.1+cpu'\n\n Returns:\n (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)\n \"\"\"\n try:\n return tuple(map(int, re.findall(r\"\\d+\", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)\n except Exception as e:\n LOGGER.warning(f\"failure for parse_version({version}), returning (0, 0, 0): {e}\")\n return 0, 0, 0", "chunk_type": "function", "name": "parse_version", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 85, "end_line": 99, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.\n\nArgs:\n version (str): Version string, i.e. '2.0.1+cpu'\n\nReturns:\n (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)", "parameters": [ "version" ], "return_type": "tuple", "decorators": [ "functools.lru_cache" ], "complexity_score": 2, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_parse_version_06c0fcd5" }, { "content": "def is_ascii(s) -> bool:\n \"\"\"\n Check if a string is composed of only ASCII characters.\n\n Args:\n s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).\n\n Returns:\n (bool): True if the string is composed only of ASCII characters, False otherwise.\n \"\"\"\n return all(ord(c) < 128 for c in str(s))", "chunk_type": "function", "name": "is_ascii", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 102, "end_line": 112, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Check if a string 
is composed of only ASCII characters.\n\nArgs:\n s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).\n\nReturns:\n (bool): True if the string is composed only of ASCII characters, False otherwise.", "parameters": [ "s" ], "return_type": "bool", "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_is_ascii_78a4aa5d" }, { "content": "def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):\n \"\"\"\n Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the\n stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.\n\n Args:\n imgsz (int | List[int]): Image size.\n stride (int): Stride value.\n min_dim (int): Minimum number of dimensions.\n max_dim (int): Maximum number of dimensions.\n floor (int): Minimum allowed value for image size.\n\n Returns:\n (List[int] | int): Updated image size.\n \"\"\"\n # Convert stride to integer if it is a tensor\n stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)\n\n # Convert image size to list if it is an integer\n if isinstance(imgsz, int):\n imgsz = [imgsz]\n elif isinstance(imgsz, (list, tuple)):\n imgsz = list(imgsz)\n elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'\n imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)\n else:\n raise TypeError(\n f\"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. \"\n f\"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'\"\n )\n\n # Apply max_dim\n if len(imgsz) > max_dim:\n msg = (\n \"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list \"\n \"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'\"\n )\n if max_dim != 1:\n raise ValueError(f\"imgsz={imgsz} is not a valid image size. {msg}\")\n LOGGER.warning(f\"updating to 'imgsz={max(imgsz)}'. 
{msg}\")\n imgsz = [max(imgsz)]\n # Make image size a multiple of the stride\n sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]\n\n # Print warning message if image size was updated\n if sz != imgsz:\n LOGGER.warning(f\"imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}\")\n\n # Add missing dimensions if necessary\n sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz\n\n return sz", "chunk_type": "function", "name": "check_imgsz", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 115, "end_line": 166, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": "Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the\nstride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.\n\nArgs:\n imgsz (int | List[int]): Image size.\n stride (int): Stride value.\n min_dim (int): Minimum number of dimensions.\n max_dim (int): Maximum number of dimensions.\n floor (int): Minimum allowed value for image size.\n\nReturns:\n (List[int] | int): Updated image size.", "parameters": [ "imgsz", "stride", "min_dim", "max_dim", "floor" ], "return_type": null, "decorators": [], "complexity_score": 8, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_imgsz_d00cc53c" }, { "content": "def check_uv():\n \"\"\"Check if uv package manager is installed and can run successfully.\"\"\"\n try:\n return subprocess.run([\"uv\", \"-V\"], capture_output=True).returncode == 0\n except FileNotFoundError:\n return False", "chunk_type": "function", "name": "check_uv", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 170, "end_line": 175, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Check if uv package manager is installed and can run successfully.", "parameters": [], "return_type": null, "decorators": [ "functools.lru_cache" ], "complexity_score": 2, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", 
"importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_uv_b28ad8d6" }, { "content": "def check_version(\n current: str = \"0.0.0\",\n required: str = \"0.0.0\",\n name: str = \"version\",\n hard: bool = False,\n verbose: bool = False,\n msg: str = \"\",\n) -> bool:\n \"\"\"\n Check current version against the required version or range.\n\n Args:\n current (str): Current version or package name to get version from.\n required (str): Required version or range (in pip-style format).\n name (str): Name to be used in warning message.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n msg (str): Extra message to display if verbose.\n\n Returns:\n (bool): True if requirement is met, False otherwise.\n\n Examples:\n Check if current version is exactly 22.04\n >>> check_version(current=\"22.04\", required=\"==22.04\")\n\n Check if current version is greater than or equal to 22.04\n >>> check_version(current=\"22.10\", required=\"22.04\") # assumes '>=' inequality if none passed\n\n Check if current version is less than or equal to 22.04\n >>> check_version(current=\"22.04\", required=\"<=22.04\")\n\n Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)\n >>> check_version(current=\"21.10\", required=\">20.04,<22.04\")\n \"\"\"\n if not current: # if current is '' or None\n LOGGER.warning(f\"invalid check_version({current}, {required}) requested, please check values.\")\n return True\n elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'\n try:\n name = current # assigned package name to 'name' arg\n current = metadata.version(current) # get version string from package name\n except metadata.PackageNotFoundError as e:\n if hard:\n raise ModuleNotFoundError(f\"{current} package is required but not installed\") from e\n else:\n return False\n\n if not required: # if required is '' or None\n return True\n\n if \"sys_platform\" in required and ( # i.e. 
required='<2.4.0,>=1.8.0; sys_platform == \"win32\"'\n (WINDOWS and \"win32\" not in required)\n or (LINUX and \"linux\" not in required)\n or (MACOS and \"macos\" not in required and \"darwin\" not in required)\n ):\n return True\n\n op = \"\"\n version = \"\"\n result = True\n c = parse_version(current) # '1.2.3' -> (1, 2, 3)\n for r in required.strip(\",\").split(\",\"):\n op, version = re.match(r\"([^0-9]*)([\\d.]+)\", r).groups() # split '>=22.04' -> ('>=', '22.04')\n if not op:\n op = \">=\" # assume >= if no op passed\n v = parse_version(version) # '1.2.3' -> (1, 2, 3)\n if op == \"==\" and c != v:\n result = False\n elif op == \"!=\" and c == v:\n result = False\n elif op == \">=\" and not (c >= v):\n result = False\n elif op == \"<=\" and not (c <= v):\n result = False\n elif op == \">\" and not (c > v):\n result = False\n elif op == \"<\" and not (c < v):\n result = False\n if not result:\n warning = f\"{name}{required} is required, but {name}=={current} is currently installed {msg}\"\n if hard:\n raise ModuleNotFoundError(warning) # assert version requirements met\n if verbose:\n LOGGER.warning(warning)\n return result", "chunk_type": "function", "name": "check_version", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 179, "end_line": 264, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Check current version against the required version or range.\n\nArgs:\n current (str): Current version or package name to get version from.\n required (str): Required version or range (in pip-style format).\n name (str): Name to be used in warning message.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n msg (str): Extra message to display if verbose.\n\nReturns:\n (bool): True if requirement is met, False otherwise.\n\nExamples:\n Check if current version is exactly 22.04\n >>> check_version(current=\"22.04\", required=\"==22.04\")\n\n Check if current version is greater than or equal to 22.04\n >>> check_version(current=\"22.10\", required=\"22.04\") # assumes '>=' inequality if none passed\n\n Check if current version is less than or equal to 22.04\n >>> check_version(current=\"22.04\", required=\"<=22.04\")\n\n Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)\n >>> check_version(current=\"21.10\", required=\">20.04,<22.04\")", "parameters": [ "current: str", "required: str", "name: str", "hard: bool", "verbose: bool", "msg: str" ], "return_type": "bool", "decorators": [ "functools.lru_cache" ], "complexity_score": 18, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", 
"ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_version_ae0b5792" }, { "content": "def check_latest_pypi_version(package_name=\"ultralytics\"):\n \"\"\"\n Return the latest version of a PyPI package without downloading or installing it.\n\n Args:\n package_name (str): The name of the package to find the latest version for.\n\n Returns:\n (str): The latest version of the package.\n \"\"\"\n import requests # slow import\n\n try:\n requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning\n response = requests.get(f\"https://pypi.org/pypi/{package_name}/json\", timeout=3)\n if response.status_code == 200:\n return response.json()[\"info\"][\"version\"]\n except Exception:\n return None", "chunk_type": "function", "name": "check_latest_pypi_version", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 267, "end_line": 285, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Return the latest version of a PyPI package without downloading or installing it.\n\nArgs:\n package_name (str): The name of the package to find the latest version for.\n\nReturns:\n (str): The latest version of the package.", "parameters": [ "package_name" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_latest_pypi_version_58a0c5a5" }, { "content": "def check_pip_update_available():\n \"\"\"\n Check if a new version of the ultralytics package is available on PyPI.\n\n Returns:\n (bool): True if an update is available, False otherwise.\n \"\"\"\n if ONLINE and IS_PIP_PACKAGE:\n try:\n from ultralytics import __version__\n\n 
latest = check_latest_pypi_version()\n if check_version(__version__, f\"<{latest}\"): # check if current version is < latest version\n LOGGER.info(\n f\"New https://pypi.org/project/ultralytics/{latest} available 😃 \"\n f\"Update with 'pip install -U ultralytics'\"\n )\n return True\n except Exception:\n pass\n return False", "chunk_type": "function", "name": "check_pip_update_available", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 288, "end_line": 308, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Check if a new version of the ultralytics package is available on PyPI.\n\nReturns:\n (bool): True if an update is available, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_pip_update_available_406d331e" }, { "content": "def check_font(font=\"Arial.ttf\"):\n \"\"\"\n Find font locally or download to user's configuration directory if it does not already exist.\n\n Args:\n font (str): Path or name of font.\n\n Returns:\n (Path): Resolved font file path.\n \"\"\"\n from matplotlib import font_manager # scope for faster 'import ultralytics'\n\n # Check USER_CONFIG_DIR\n name = Path(font).name\n file = USER_CONFIG_DIR / name\n if file.exists():\n return file\n\n # Check system fonts\n matches = [s for s in font_manager.findSystemFonts() if font in s]\n if any(matches):\n return matches[0]\n\n # Download to USER_CONFIG_DIR if missing\n url = f\"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}\"\n if downloads.is_url(url, check=True):\n downloads.safe_download(url=url, file=file)\n return file", "chunk_type": "function", "name": "check_font", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 313, "end_line": 340, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Find font locally or download to user's configuration directory if it does not already exist.\n\nArgs:\n font (str): Path or name of font.\n\nReturns:\n (Path): Resolved font file path.", "parameters": [ "font" ], "return_type": null, "decorators": [ 
"ThreadingLocked()", "functools.lru_cache" ], "complexity_score": 5, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_font_fd626ed1" }, { "content": "def check_python(minimum: str = \"3.8.0\", hard: bool = True, verbose: bool = False) -> bool:\n \"\"\"\n Check current python version against the required minimum version.\n\n Args:\n minimum (str): Required minimum version of python.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n\n Returns:\n (bool): Whether the installed Python version meets the minimum constraints.\n \"\"\"\n return check_version(PYTHON_VERSION, minimum, name=\"Python\", hard=hard, verbose=verbose)", "chunk_type": "function", "name": "check_python", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 343, "end_line": 355, "start_col": 0, "end_col": 92, "parent_name": null, "docstring": "Check current python version against the required minimum version.\n\nArgs:\n minimum (str): Required minimum version of python.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n\nReturns:\n (bool): Whether the installed Python version meets the minimum constraints.", "parameters": [ "minimum: str", "hard: bool", "verbose: bool" ], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", 
"ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_python_f5574175" }, { "content": "def check_requirements(requirements=ROOT.parent / \"requirements.txt\", exclude=(), install=True, cmds=\"\"):\n \"\"\"\n Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.\n\n Args:\n requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a\n string, or a list of package requirements as strings.\n exclude (tuple): Tuple of package names to exclude from checking.\n install (bool): If True, attempt to auto-update packages that don't meet requirements.\n cmds (str): Additional commands to pass to the pip install command when auto-updating.\n\n Examples:\n >>> from ultralytics.utils.checks import check_requirements\n\n Check a requirements.txt file\n >>> check_requirements(\"path/to/requirements.txt\")\n\n Check a single package\n >>> check_requirements(\"ultralytics>=8.0.0\")\n\n Check multiple packages\n >>> check_requirements([\"numpy\", \"ultralytics>=8.0.0\"])\n \"\"\"\n prefix = colorstr(\"red\", \"bold\", \"requirements:\")\n if isinstance(requirements, Path): # requirements.txt file\n file = requirements.resolve()\n assert file.exists(), f\"{prefix} {file} not found, check failed.\"\n requirements = [f\"{x.name}{x.specifier}\" for x in parse_requirements(file) if x.name not in exclude]\n elif isinstance(requirements, str):\n requirements = [requirements]\n\n pkgs = []\n for r in requirements:\n r_stripped = r.rpartition(\"/\")[-1].replace(\".git\", \"\") # replace git+https://org/repo.git -> 'repo'\n match = re.match(r\"([a-zA-Z0-9-_]+)([<>!=~]+.*)?\", r_stripped)\n name, required = match[1], match[2].strip() if match[2] else \"\"\n try:\n assert check_version(metadata.version(name), required) # exception if requirements not met\n except (AssertionError, metadata.PackageNotFoundError):\n pkgs.append(r)\n\n @Retry(times=2, delay=1)\n def attempt_install(packages, commands, use_uv):\n \"\"\"Attempt package installation with uv if available, falling back to pip.\"\"\"\n if use_uv:\n base = f\"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow\"\n try:\n return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode()\n except subprocess.CalledProcessError as e:\n if e.stderr and \"No virtual environment found\" in e.stderr.decode():\n return subprocess.check_output(\n base.replace(\"uv pip install\", \"uv pip install --system\"), shell=True\n ).decode()\n raise\n return subprocess.check_output(f\"pip install --no-cache-dir {packages} {commands}\", shell=True).decode()\n\n s = \" \".join(f'\"{x}\"' for x in pkgs) # console string\n if s:\n if install and AUTOINSTALL: # check 
environment variable\n # Note uv fails on arm64 macOS and Raspberry Pi runners\n n = len(pkgs) # number of packages updates\n LOGGER.info(f\"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...\")\n try:\n t = time.time()\n assert ONLINE, \"AutoUpdate skipped (offline)\"\n LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv()))\n dt = time.time() - t\n LOGGER.info(f\"{prefix} AutoUpdate success ✅ {dt:.1f}s\")\n LOGGER.warning(\n f\"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\\n\"\n )\n except Exception as e:\n LOGGER.warning(f\"{prefix} ❌ {e}\")\n return False\n else:\n return False\n\n return True", "chunk_type": "function", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 359, "end_line": 436, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.\n\nArgs:\n requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a\n string, or a list of package requirements as strings.\n exclude (tuple): Tuple of package names to exclude from checking.\n install (bool): If True, attempt to auto-update packages that don't meet requirements.\n cmds (str): Additional commands to pass to the pip install command when auto-updating.\n\nExamples:\n >>> from ultralytics.utils.checks import check_requirements\n\n Check a requirements.txt file\n >>> check_requirements(\"path/to/requirements.txt\")\n\n Check a single package\n >>> check_requirements(\"ultralytics>=8.0.0\")\n\n Check multiple packages\n >>> check_requirements([\"numpy\", \"ultralytics>=8.0.0\"])", "parameters": [ "requirements", "exclude", "install", "cmds" ], "return_type": null, "decorators": [ "TryExcept()" ], "complexity_score": 13, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_requirements_8db63c68" }, { "content": "def check_torchvision():\n \"\"\"\n Check the installed versions of PyTorch and Torchvision to ensure they're compatible.\n\n 
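In practice `check_requirements` above is called defensively right before a lazy import, as `profile_onnx_model` does with onnxruntime earlier in this file, so a missing package can be auto-installed when AUTOINSTALL allows it; a typical hedged usage (the pinned versions are illustrative):

```python
from ultralytics.utils.checks import check_requirements

check_requirements("onnxruntime")  # verify (and optionally auto-install) before the lazy import below
check_requirements(["numpy", "ultralytics>=8.0.0"], install=False)  # check only, never auto-update
import onnxruntime as ort  # noqa: E402, lazy import performed after the check
```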
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according\n to the compatibility table based on: https://github.com/pytorch/vision#installation.\n \"\"\"\n compatibility_table = {\n \"2.7\": [\"0.22\"],\n \"2.6\": [\"0.21\"],\n \"2.5\": [\"0.20\"],\n \"2.4\": [\"0.19\"],\n \"2.3\": [\"0.18\"],\n \"2.2\": [\"0.17\"],\n \"2.1\": [\"0.16\"],\n \"2.0\": [\"0.15\"],\n \"1.13\": [\"0.14\"],\n \"1.12\": [\"0.13\"],\n }\n\n # Check major and minor versions\n v_torch = \".\".join(torch.__version__.split(\"+\", 1)[0].split(\".\")[:2])\n if v_torch in compatibility_table:\n compatible_versions = compatibility_table[v_torch]\n v_torchvision = \".\".join(TORCHVISION_VERSION.split(\"+\", 1)[0].split(\".\")[:2])\n if all(v_torchvision != v for v in compatible_versions):\n LOGGER.warning(\n f\"torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\\n\"\n f\"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or \"\n \"'pip install -U torch torchvision' to update both.\\n\"\n \"For a full compatibility table see https://github.com/pytorch/vision#installation\"\n )", "chunk_type": "function", "name": "check_torchvision", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 439, "end_line": 470, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": "Check the installed versions of PyTorch and Torchvision to ensure they're compatible.\n\nThis function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according\nto the compatibility table based on: https://github.com/pytorch/vision#installation.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_torchvision_9377ec30" }, { "content": "def check_suffix(file=\"yolo11n.pt\", suffix=\".pt\", msg=\"\"):\n \"\"\"\n Check file(s) for acceptable suffix.\n\n Args:\n file (str | List[str]): File or list of files to check.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes.\n msg (str): Additional message to display in case of error.\n 
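The compatibility check above compares only major.minor components after stripping any local build tag such as '+cu121'; the trimming step in isolation, with version strings chosen purely for illustration:

```python
def major_minor(version: str) -> str:
    """Reduce '2.3.1+cu121' to '2.3', mirroring the trimming in check_torchvision()."""
    return ".".join(version.split("+", 1)[0].split(".")[:2])


compatibility_table = {"2.3": ["0.18"], "2.2": ["0.17"]}  # excerpt of the table above
v_torch, v_torchvision = major_minor("2.3.1+cu121"), major_minor("0.18.1")
print(v_torch, v_torchvision)  # 2.3 0.18
print(v_torchvision in compatibility_table.get(v_torch, []))  # True -> compatible pair
```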
\"\"\"\n if file and suffix:\n if isinstance(suffix, str):\n suffix = {suffix}\n for f in file if isinstance(file, (list, tuple)) else [file]:\n if s := str(f).rpartition(\".\")[-1].lower().strip(): # file suffix\n assert f\".{s}\" in suffix, f\"{msg}{f} acceptable suffix is {suffix}, not .{s}\"", "chunk_type": "function", "name": "check_suffix", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 473, "end_line": 487, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": "Check file(s) for acceptable suffix.\n\nArgs:\n file (str | List[str]): File or list of files to check.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes.\n msg (str): Additional message to display in case of error.", "parameters": [ "file", "suffix", "msg" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_suffix_2c408922" }, { "content": "def check_yolov5u_filename(file: str, verbose: bool = True):\n \"\"\"\n Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.\n\n Args:\n file (str): Filename to check and potentially update.\n verbose (bool): Whether to print information about the replacement.\n\n Returns:\n (str): Updated filename.\n \"\"\"\n if \"yolov3\" in file or \"yolov5\" in file:\n if \"u.yaml\" in file:\n file = file.replace(\"u.yaml\", \".yaml\") # i.e. yolov5nu.yaml -> yolov5n.yaml\n elif \".pt\" in file and \"u\" not in file:\n original_file = file\n file = re.sub(r\"(.*yolov5([nsmlx]))\\.pt\", \"\\\\1u.pt\", file) # i.e. yolov5n.pt -> yolov5nu.pt\n file = re.sub(r\"(.*yolov5([nsmlx])6)\\.pt\", \"\\\\1u.pt\", file) # i.e. yolov5n6.pt -> yolov5n6u.pt\n file = re.sub(r\"(.*yolov3(|-tiny|-spp))\\.pt\", \"\\\\1u.pt\", file) # i.e. 
yolov3-spp.pt -> yolov3-sppu.pt\n if file != original_file and verbose:\n LOGGER.info(\n f\"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\\nYOLOv5 'u' models are \"\n f\"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs \"\n f\"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\\n\"\n )\n return file", "chunk_type": "function", "name": "check_yolov5u_filename", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 490, "end_line": 515, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.\n\nArgs:\n file (str): Filename to check and potentially update.\n verbose (bool): Whether to print information about the replacement.\n\nReturns:\n (str): Updated filename.", "parameters": [ "file: str", "verbose: bool" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_yolov5u_filename_c6c07a40" }, { "content": "def check_model_file_from_stem(model=\"yolo11n\"):\n \"\"\"\n Return a model filename from a valid model stem.\n\n Args:\n model (str): Model stem to check.\n\n Returns:\n (str | Path): Model filename with appropriate suffix.\n \"\"\"\n path = Path(model)\n if not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS:\n return path.with_suffix(\".pt\") # add suffix, i.e. 
yolo11n -> yolo11n.pt\n return model", "chunk_type": "function", "name": "check_model_file_from_stem", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 518, "end_line": 531, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Return a model filename from a valid model stem.\n\nArgs:\n model (str): Model stem to check.\n\nReturns:\n (str | Path): Model filename with appropriate suffix.", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_model_file_from_stem_585ca3c7" }, { "content": "def check_file(file, suffix=\"\", download=True, download_dir=\".\", hard=True):\n \"\"\"\n Search/download file (if necessary), check suffix (if provided), and return path.\n\n Args:\n file (str): File name or path.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.\n download (bool): Whether to download the file if it doesn't exist locally.\n download_dir (str): Directory to download the file to.\n hard (bool): Whether to raise an error if the file is not found.\n\n Returns:\n (str): Path to the file.\n \"\"\"\n check_suffix(file, suffix) # optional\n file = str(file).strip() # convert to string and strip spaces\n file = check_yolov5u_filename(file) # yolov5n -> yolov5nu\n if (\n not file\n or (\"://\" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10\n or file.lower().startswith(\"grpc://\")\n ): # file exists or gRPC Triton images\n return file\n elif download and file.lower().startswith((\"https://\", \"http://\", \"rtsp://\", \"rtmp://\", \"tcp://\")): # download\n url = file # warning: Pathlib turns :// -> :/\n file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth\n if file.exists():\n LOGGER.info(f\"Found {clean_url(url)} locally at {file}\") # file already exists\n else:\n downloads.safe_download(url=url, file=file, unzip=False)\n return str(file)\n else: # search\n files = glob.glob(str(ROOT / \"**\" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file\n if not 
files and hard:\n raise FileNotFoundError(f\"'{file}' does not exist\")\n elif len(files) > 1 and hard:\n raise FileNotFoundError(f\"Multiple files match '{file}', specify exact path: {files}\")\n return files[0] if len(files) else [] # return file", "chunk_type": "function", "name": "check_file", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 534, "end_line": 571, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "Search/download file (if necessary), check suffix (if provided), and return path.\n\nArgs:\n file (str): File name or path.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.\n download (bool): Whether to download the file if it doesn't exist locally.\n download_dir (str): Directory to download the file to.\n hard (bool): Whether to raise an error if the file is not found.\n\nReturns:\n (str): Path to the file.", "parameters": [ "file", "suffix", "download", "download_dir", "hard" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_file_f59712e4" }, { "content": "def check_yaml(file, suffix=(\".yaml\", \".yml\"), hard=True):\n \"\"\"\n Search/download YAML file (if necessary) and return path, checking suffix.\n\n Args:\n file (str | Path): File name or path.\n suffix (tuple): Tuple of acceptable YAML file suffixes.\n hard (bool): Whether to raise an error if the file is not found or multiple files are found.\n\n Returns:\n (str): Path to the YAML file.\n \"\"\"\n return check_file(file, suffix, hard=hard)", "chunk_type": "function", "name": "check_yaml", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 574, "end_line": 586, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Search/download YAML file (if necessary) and return path, checking suffix.\n\nArgs:\n file (str | Path): File name or path.\n suffix (tuple): Tuple of acceptable YAML file suffixes.\n hard (bool): Whether to raise an error if the file is not found or multiple files are found.\n\nReturns:\n (str): Path to the YAML file.", "parameters": [ "file", 
"suffix", "hard" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_yaml_3af24e9b" }, { "content": "def check_is_path_safe(basedir, path):\n \"\"\"\n Check if the resolved path is under the intended directory to prevent path traversal.\n\n Args:\n basedir (Path | str): The intended directory.\n path (Path | str): The path to check.\n\n Returns:\n (bool): True if the path is safe, False otherwise.\n \"\"\"\n base_dir_resolved = Path(basedir).resolve()\n path_resolved = Path(path).resolve()\n\n return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts", "chunk_type": "function", "name": "check_is_path_safe", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 589, "end_line": 603, "start_col": 0, "end_col": 116, "parent_name": null, "docstring": "Check if the resolved path is under the intended directory to prevent path traversal.\n\nArgs:\n basedir (Path | str): The intended directory.\n path (Path | str): The path to check.\n\nReturns:\n (bool): True if the path is safe, False otherwise.", "parameters": [ "basedir", "path" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", 
"ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_is_path_safe_1b44c9a8" }, { "content": "def check_imshow(warn=False):\n \"\"\"\n Check if environment supports image displays.\n\n Args:\n warn (bool): Whether to warn if environment doesn't support image displays.\n\n Returns:\n (bool): True if environment supports image displays, False otherwise.\n \"\"\"\n try:\n if LINUX:\n assert not IS_COLAB and not IS_KAGGLE\n assert \"DISPLAY\" in os.environ, \"The DISPLAY environment variable isn't set.\"\n cv2.imshow(\"test\", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image\n cv2.waitKey(1)\n cv2.destroyAllWindows()\n cv2.waitKey(1)\n return True\n except Exception as e:\n if warn:\n LOGGER.warning(f\"Environment does not support cv2.imshow() or PIL Image.show()\\n{e}\")\n return False", "chunk_type": "function", "name": "check_imshow", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 607, "end_line": 629, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Check if environment supports image displays.\n\nArgs:\n warn (bool): Whether to warn if environment doesn't support image displays.\n\nReturns:\n (bool): True if environment supports image displays, False otherwise.", "parameters": [ "warn" ], "return_type": null, "decorators": [ "functools.lru_cache" ], "complexity_score": 4, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_imshow_b8705279" }, { "content": "def check_yolo(verbose=True, device=\"\"):\n \"\"\"\n Return a human-readable YOLO software and hardware summary.\n\n Args:\n verbose (bool): Whether to print verbose 
information.\n device (str | torch.device): Device to use for YOLO.\n \"\"\"\n import psutil\n\n from ultralytics.utils.torch_utils import select_device\n\n if IS_COLAB:\n shutil.rmtree(\"sample_data\", ignore_errors=True) # remove colab /sample_data directory\n\n if verbose:\n # System info\n gib = 1 << 30 # bytes per GiB\n ram = psutil.virtual_memory().total\n total, used, free = shutil.disk_usage(\"/\")\n s = f\"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)\"\n try:\n from IPython import display\n\n display.clear_output() # clear display if notebook\n except ImportError:\n pass\n else:\n s = \"\"\n\n select_device(device=device, newline=False)\n LOGGER.info(f\"Setup complete ✅ {s}\")", "chunk_type": "function", "name": "check_yolo", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 632, "end_line": 663, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Return a human-readable YOLO software and hardware summary.\n\nArgs:\n verbose (bool): Whether to print verbose information.\n device (str | torch.device): Device to use for YOLO.", "parameters": [ "verbose", "device" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_yolo_7a025cfe" }, { "content": "def collect_system_info():\n \"\"\"\n Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.\n\n Returns:\n (dict): Dictionary containing system information.\n \"\"\"\n import psutil\n\n from ultralytics.utils import ENVIRONMENT # scope to avoid circular import\n from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info\n\n gib = 1 << 30 # bytes per GiB\n cuda = torch.cuda.is_available()\n check_yolo()\n total, used, free = shutil.disk_usage(\"/\")\n\n info_dict = {\n \"OS\": platform.platform(),\n \"Environment\": ENVIRONMENT,\n \"Python\": PYTHON_VERSION,\n \"Install\": \"git\" if IS_GIT_DIR else \"pip\" if IS_PIP_PACKAGE else \"other\",\n \"Path\": str(ROOT),\n \"RAM\": f\"{psutil.virtual_memory().total / gib:.2f} GB\",\n \"Disk\": f\"{(total - free) / 
gib:.1f}/{total / gib:.1f} GB\",\n \"CPU\": get_cpu_info(),\n \"CPU count\": os.cpu_count(),\n \"GPU\": get_gpu_info(index=0) if cuda else None,\n \"GPU count\": torch.cuda.device_count() if cuda else None,\n \"CUDA\": torch.version.cuda if cuda else None,\n }\n LOGGER.info(\"\\n\" + \"\\n\".join(f\"{k:<20}{v}\" for k, v in info_dict.items()) + \"\\n\")\n\n package_info = {}\n for r in parse_requirements(package=\"ultralytics\"):\n try:\n current = metadata.version(r.name)\n is_met = \"✅ \" if check_version(current, str(r.specifier), name=r.name, hard=True) else \"❌ \"\n except metadata.PackageNotFoundError:\n current = \"(not installed)\"\n is_met = \"❌ \"\n package_info[r.name] = f\"{is_met}{current}{r.specifier}\"\n LOGGER.info(f\"{r.name:<20}{package_info[r.name]}\")\n\n info_dict[\"Package Info\"] = package_info\n\n if is_github_action_running():\n github_info = {\n \"RUNNER_OS\": os.getenv(\"RUNNER_OS\"),\n \"GITHUB_EVENT_NAME\": os.getenv(\"GITHUB_EVENT_NAME\"),\n \"GITHUB_WORKFLOW\": os.getenv(\"GITHUB_WORKFLOW\"),\n \"GITHUB_ACTOR\": os.getenv(\"GITHUB_ACTOR\"),\n \"GITHUB_REPOSITORY\": os.getenv(\"GITHUB_REPOSITORY\"),\n \"GITHUB_REPOSITORY_OWNER\": os.getenv(\"GITHUB_REPOSITORY_OWNER\"),\n }\n LOGGER.info(\"\\n\" + \"\\n\".join(f\"{k}: {v}\" for k, v in github_info.items()))\n info_dict[\"GitHub Info\"] = github_info\n\n return info_dict", "chunk_type": "function", "name": "collect_system_info", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 666, "end_line": 724, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.\n\nReturns:\n (dict): Dictionary containing system information.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_collect_system_info_5319b4d4" }, { "content": "def check_amp(model):\n \"\"\"\n Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.\n\n If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP\n 
results, so AMP will be disabled during training.\n\n Args:\n model (torch.nn.Module): A YOLO model instance.\n\n Returns:\n (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> from ultralytics.utils.checks import check_amp\n >>> model = YOLO(\"yolo11n.pt\").model.cuda()\n >>> check_amp(model)\n \"\"\"\n from ultralytics.utils.torch_utils import autocast\n\n device = next(model.parameters()).device # get model device\n prefix = colorstr(\"AMP: \")\n if device.type in {\"cpu\", \"mps\"}:\n return False # AMP only used on CUDA devices\n else:\n # GPUs that have issues with AMP\n pattern = re.compile(\n r\"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)\", re.IGNORECASE\n )\n\n gpu = torch.cuda.get_device_name(device)\n if bool(pattern.search(gpu)):\n LOGGER.warning(\n f\"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause \"\n f\"NaN losses or zero-mAP results, so AMP will be disabled during training.\"\n )\n return False\n\n def amp_allclose(m, im):\n \"\"\"All close FP32 vs AMP results.\"\"\"\n batch = [im] * 8\n imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64\n a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference\n with autocast(enabled=True):\n b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference\n del m\n return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance\n\n im = ASSETS / \"bus.jpg\" # image to check\n LOGGER.info(f\"{prefix}running Automatic Mixed Precision (AMP) checks...\")\n warning_msg = \"Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False.\"\n try:\n from ultralytics import YOLO\n\n assert amp_allclose(YOLO(\"yolo11n.pt\"), im)\n LOGGER.info(f\"{prefix}checks passed ✅\")\n except ConnectionError:\n LOGGER.warning(f\"{prefix}checks skipped. Offline and unable to download YOLO11n for AMP checks. {warning_msg}\")\n except (AttributeError, ModuleNotFoundError):\n LOGGER.warning(\n f\"{prefix}checks skipped. \"\n f\"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}\"\n )\n except AssertionError:\n LOGGER.error(\n f\"{prefix}checks failed. 
Anomalies were detected with AMP on your system that may lead to \"\n f\"NaN losses or zero-mAP results, so AMP will be disabled during training.\"\n )\n return False\n return True", "chunk_type": "function", "name": "check_amp", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 727, "end_line": 797, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.\n\nIf the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP\nresults, so AMP will be disabled during training.\n\nArgs:\n model (torch.nn.Module): A YOLO model instance.\n\nReturns:\n (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.\n\nExamples:\n >>> from ultralytics import YOLO\n >>> from ultralytics.utils.checks import check_amp\n >>> model = YOLO(\"yolo11n.pt\").model.cuda()\n >>> check_amp(model)", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_check_amp_5eaf9c64" }, { "content": "def git_describe(path=ROOT): # path must be a directory\n \"\"\"\n Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.\n\n Args:\n path (Path): Path to git repository.\n\n Returns:\n (str): Human-readable git description.\n \"\"\"\n try:\n return subprocess.check_output(f\"git -C {path} describe --tags --long --always\", shell=True).decode()[:-1]\n except Exception:\n return \"\"", "chunk_type": "function", "name": "git_describe", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 800, "end_line": 813, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.\n\nArgs:\n path (Path): Path to git repository.\n\nReturns:\n (str): Human-readable git description.", "parameters": [ "path" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_git_describe_0ce7e654" }, { "content": "def print_args(args: Optional[dict] = None, show_file=True, show_func=False):\n \"\"\"\n Print function arguments (optional args dict).\n\n Args:\n args (dict, optional): Arguments to print.\n show_file (bool): Whether to show the file name.\n show_func (bool): Whether to show the function name.\n \"\"\"\n\n def strip_auth(v):\n \"\"\"Clean longer Ultralytics HUB URLs by stripping potential authentication information.\"\"\"\n return clean_url(v) if (isinstance(v, str) and v.startswith(\"http\") and len(v) > 100) else v\n\n x = inspect.currentframe().f_back # previous frame\n file, _, func, _, _ = inspect.getframeinfo(x)\n if args is None: # get args automatically\n args, _, _, frm = inspect.getargvalues(x)\n args = {k: v for k, v in frm.items() if k in args}\n try:\n file = Path(file).resolve().relative_to(ROOT).with_suffix(\"\")\n except ValueError:\n file = Path(file).stem\n s = (f\"{file}: \" if show_file else \"\") + (f\"{func}: \" if show_func else \"\")\n LOGGER.info(colorstr(s) + \", \".join(f\"{k}={strip_auth(v)}\" for k, v in sorted(args.items())))", "chunk_type": "function", "name": "print_args", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 816, "end_line": 840, "start_col": 0, "end_col": 97, "parent_name": null, "docstring": "Print function arguments (optional args dict).\n\nArgs:\n args (dict, optional): Arguments to print.\n show_file (bool): Whether to show the file name.\n show_func (bool): Whether to show the function name.", "parameters": [ "args: Optional[dict]", "show_file", "show_func" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", 
"typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_print_args_927eccde" }, { "content": "def cuda_device_count() -> int:\n \"\"\"\n Get the number of NVIDIA GPUs available in the environment.\n\n Returns:\n (int): The number of NVIDIA GPUs available.\n \"\"\"\n if IS_JETSON:\n # NVIDIA Jetson does not fully support nvidia-smi and therefore use PyTorch instead\n return torch.cuda.device_count()\n else:\n try:\n # Run the nvidia-smi command and capture its output\n output = subprocess.check_output(\n [\"nvidia-smi\", \"--query-gpu=count\", \"--format=csv,noheader,nounits\"], encoding=\"utf-8\"\n )\n\n # Take the first line and strip any leading/trailing white space\n first_line = output.strip().split(\"\\n\", 1)[0]\n\n return int(first_line)\n except (subprocess.CalledProcessError, FileNotFoundError, ValueError):\n # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available\n return 0", "chunk_type": "function", "name": "cuda_device_count", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 843, "end_line": 866, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Get the number of NVIDIA GPUs available in the environment.\n\nReturns:\n (int): The number of NVIDIA GPUs available.", "parameters": [], "return_type": "int", "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", 
"ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_cuda_device_count_464b91d4" }, { "content": "def cuda_is_available() -> bool:\n \"\"\"\n Check if CUDA is available in the environment.\n\n Returns:\n (bool): True if one or more NVIDIA GPUs are available, False otherwise.\n \"\"\"\n return cuda_device_count() > 0", "chunk_type": "function", "name": "cuda_is_available", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 869, "end_line": 876, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Check if CUDA is available in the environment.\n\nReturns:\n (bool): True if one or more NVIDIA GPUs are available, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_cuda_is_available_61f8d92d" }, { "content": "def is_rockchip():\n \"\"\"\n Check if the current environment is running on a Rockchip SoC.\n\n Returns:\n (bool): True if running on a Rockchip SoC, False otherwise.\n \"\"\"\n if LINUX and ARM64:\n try:\n with open(\"/proc/device-tree/compatible\") as f:\n dev_str = f.read()\n *_, soc = dev_str.split(\",\")\n if soc.replace(\"\\x00\", \"\") in RKNN_CHIPS:\n return True\n except OSError:\n return False\n else:\n return False", "chunk_type": "function", "name": "is_rockchip", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 879, "end_line": 896, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Check if the current environment is running on a Rockchip SoC.\n\nReturns:\n (bool): True if running on a Rockchip SoC, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 4, 
"dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_is_rockchip_2fe49297" }, { "content": "def is_intel():\n \"\"\"\n Check if the system has Intel hardware (CPU or GPU).\n\n Returns:\n (bool): True if Intel hardware is detected, False otherwise.\n \"\"\"\n from ultralytics.utils.torch_utils import get_cpu_info\n\n # Check CPU\n if \"intel\" in get_cpu_info().lower():\n return True\n\n # Check GPU via xpu-smi\n try:\n result = subprocess.run([\"xpu-smi\", \"discovery\"], capture_output=True, text=True, timeout=5)\n return \"intel\" in result.stdout.lower()\n except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError):\n return False", "chunk_type": "function", "name": "is_intel", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 899, "end_line": 917, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Check if the system has Intel hardware (CPU or GPU).\n\nReturns:\n (bool): True if Intel hardware is detected, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", 
"matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_is_intel_375f87f7" }, { "content": "def is_sudo_available() -> bool:\n \"\"\"\n Check if the sudo command is available in the environment.\n\n Returns:\n (bool): True if the sudo command is available, False otherwise.\n \"\"\"\n if WINDOWS:\n return False\n cmd = \"sudo --version\"\n return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0", "chunk_type": "function", "name": "is_sudo_available", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 920, "end_line": 930, "start_col": 0, "end_col": 112, "parent_name": null, "docstring": "Check if the sudo command is available in the environment.\n\nReturns:\n (bool): True if the sudo command is available, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "glob", "inspect", "math", "os", "platform", "re", "shutil", "subprocess", "time", "importlib.metadata", "pathlib.Path", "types.SimpleNamespace", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.ARM64", "ultralytics.utils.ASSETS", "ultralytics.utils.AUTOINSTALL", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_GIT_DIR", "ultralytics.utils.IS_JETSON", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.IS_PIP_PACKAGE", "ultralytics.utils.LINUX", "ultralytics.utils.LOGGER", "ultralytics.utils.MACOS", "ultralytics.utils.ONLINE", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.RKNN_CHIPS", "ultralytics.utils.ROOT", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.USER_CONFIG_DIR", "ultralytics.utils.WINDOWS", "ultralytics.utils.Retry", "ultralytics.utils.ThreadingLocked", "ultralytics.utils.TryExcept", "ultralytics.utils.clean_url", "ultralytics.utils.colorstr", "ultralytics.utils.downloads", "ultralytics.utils.is_github_action_running", "ultralytics.utils.url2file", "requests", "matplotlib.font_manager", "psutil", "ultralytics.utils.torch_utils.select_device", "psutil", "ultralytics.utils.ENVIRONMENT", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.utils.torch_utils.get_gpu_info", "ultralytics.utils.torch_utils.autocast", "ultralytics.utils.torch_utils.get_cpu_info", "ultralytics.YOLO", "ultralytics.__version__", "IPython.display" ], "chunk_id": "function_is_sudo_available_7bc7ce4d" }, { "content": "IS_PYTHON_3_8 = PYTHON_VERSION.startswith(\"3.8\")", "chunk_type": "variable", "name": "IS_PYTHON_3_8", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 938, "end_line": 938, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_PYTHON_3_8_17224886" }, { "content": "IS_PYTHON_3_12 = PYTHON_VERSION.startswith(\"3.12\")", "chunk_type": "variable", "name": "IS_PYTHON_3_12", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 939, "end_line": 939, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, 
"chunk_id": "variable_IS_PYTHON_3_12_be9b4c7b" }, { "content": "IS_PYTHON_3_13 = PYTHON_VERSION.startswith(\"3.13\")", "chunk_type": "variable", "name": "IS_PYTHON_3_13", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 940, "end_line": 940, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_PYTHON_3_13_915d7a72" }, { "content": "IS_PYTHON_MINIMUM_3_10 = check_python(\"3.10\", hard=False)", "chunk_type": "variable", "name": "IS_PYTHON_MINIMUM_3_10", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 942, "end_line": 942, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_PYTHON_MINIMUM_3_10_0cceec0b" }, { "content": "IS_PYTHON_MINIMUM_3_12 = check_python(\"3.12\", hard=False)", "chunk_type": "variable", "name": "IS_PYTHON_MINIMUM_3_12", "file_path": "ultralytics\\ultralytics\\utils\\checks.py", "start_line": 943, "end_line": 943, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_PYTHON_MINIMUM_3_12_26a33f45" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_1fbc79e0" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_8a4c838b" }, { "content": "import sys", "chunk_type": "import", "name": "sys", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_sys_28c13f74" }, { "content": "import tempfile", "chunk_type": "import", "name": "tempfile", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_tempfile_20828fc4" }, { "content": "from . 
import USER_CONFIG_DIR", "chunk_type": "import", "name": "USER_CONFIG_DIR", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_USER_CONFIG_DIR_d2b6230b" }, { "content": "from .torch_utils import TORCH_1_9", "chunk_type": "import", "name": "TORCH_1_9", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_9_b327005c" }, { "content": "def find_free_network_port() -> int:\n \"\"\"\n Find a free port on localhost.\n\n It is useful in single-node training when we don't want to connect to a real main node but have to set the\n `MASTER_PORT` environment variable.\n\n Returns:\n (int): The available network port number.\n \"\"\"\n import socket\n\n with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n s.bind((\"127.0.0.1\", 0))\n return s.getsockname()[1] # port", "chunk_type": "function", "name": "find_free_network_port", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 12, "end_line": 26, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Find a free port on localhost.\n\nIt is useful in single-node training when we don't want to connect to a real main node but have to set the\n`MASTER_PORT` environment variable.\n\nReturns:\n (int): The available network port number.", "parameters": [], "return_type": "int", "decorators": [], "complexity_score": 1, "dependencies": [ "os", "shutil", "sys", "tempfile", "USER_CONFIG_DIR", "torch_utils.TORCH_1_9", "socket", "__main__" ], "chunk_id": "function_find_free_network_port_f7ebaaa5" }, { "content": "def generate_ddp_file(trainer):\n \"\"\"\n Generate a DDP (Distributed Data Parallel) file for multi-GPU training.\n\n This function creates a temporary Python file that enables distributed training across multiple GPUs.\n The file contains the necessary configuration to initialize the trainer in a distributed environment.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.\n Must have args attribute and be a class instance.\n\n Returns:\n (str): Path to the generated temporary DDP file.\n\n Notes:\n The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:\n - Trainer class import\n - Configuration overrides from the trainer arguments\n - Model path configuration\n - Training initialization code\n \"\"\"\n module, name = f\"{trainer.__class__.__module__}.{trainer.__class__.__name__}\".rsplit(\".\", 1)\n\n content = f\"\"\"\n# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)\noverrides = {vars(trainer.args)}\n\nif __name__ == \"__main__\":\n from {module} import {name}\n from ultralytics.utils import DEFAULT_CFG_DICT\n\n cfg = DEFAULT_CFG_DICT.copy()\n cfg.update(save_dir='') # handle the extra key 'save_dir'\n trainer = {name}(cfg=cfg, overrides=overrides)\n trainer.args.model = \"{getattr(trainer.hub_session, \"model_url\", trainer.args.model)}\"\n results = trainer.train()\n\"\"\"\n (USER_CONFIG_DIR / \"DDP\").mkdir(exist_ok=True)\n with tempfile.NamedTemporaryFile(\n prefix=\"_temp_\",\n 
suffix=f\"{id(trainer)}.py\",\n mode=\"w+\",\n encoding=\"utf-8\",\n dir=USER_CONFIG_DIR / \"DDP\",\n delete=False,\n ) as file:\n file.write(content)\n return file.name", "chunk_type": "function", "name": "generate_ddp_file", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 29, "end_line": 76, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Generate a DDP (Distributed Data Parallel) file for multi-GPU training.\n\nThis function creates a temporary Python file that enables distributed training across multiple GPUs.\nThe file contains the necessary configuration to initialize the trainer in a distributed environment.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.\n Must have args attribute and be a class instance.\n\nReturns:\n (str): Path to the generated temporary DDP file.\n\nNotes:\n The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:\n - Trainer class import\n - Configuration overrides from the trainer arguments\n - Model path configuration\n - Training initialization code", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "os", "shutil", "sys", "tempfile", "USER_CONFIG_DIR", "torch_utils.TORCH_1_9", "socket", "__main__" ], "chunk_id": "function_generate_ddp_file_2558acd0" }, { "content": "def generate_ddp_command(world_size: int, trainer):\n \"\"\"\n Generate command for distributed training.\n\n Args:\n world_size (int): Number of processes to spawn for distributed training.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.\n\n Returns:\n cmd (List[str]): The command to execute for distributed training.\n file (str): Path to the temporary file created for DDP training.\n \"\"\"\n import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218\n\n if not trainer.resume:\n shutil.rmtree(trainer.save_dir) # remove the save_dir\n file = generate_ddp_file(trainer)\n dist_cmd = \"torch.distributed.run\" if TORCH_1_9 else \"torch.distributed.launch\"\n port = find_free_network_port()\n cmd = [sys.executable, \"-m\", dist_cmd, \"--nproc_per_node\", f\"{world_size}\", \"--master_port\", f\"{port}\", file]\n return cmd, file", "chunk_type": "function", "name": "generate_ddp_command", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 79, "end_line": 99, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Generate command for distributed training.\n\nArgs:\n world_size (int): Number of processes to spawn for distributed training.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.\n\nReturns:\n cmd (List[str]): The command to execute for distributed training.\n file (str): Path to the temporary file created for DDP training.", "parameters": [ "world_size: int", "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "shutil", "sys", "tempfile", "USER_CONFIG_DIR", "torch_utils.TORCH_1_9", "socket", "__main__" ], "chunk_id": "function_generate_ddp_command_377d6fdd" }, { "content": "def ddp_cleanup(trainer, file):\n \"\"\"\n Delete temporary file if created during distributed data parallel (DDP) training.\n\n This function checks if the provided file contains the trainer's ID in its name, indicating it was created\n as a temporary file for DDP training, 
and deletes it if so.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.\n file (str): Path to the file that might need to be deleted.\n\n Examples:\n >>> trainer = YOLOTrainer()\n >>> file = \"/tmp/ddp_temp_123456789.py\"\n >>> ddp_cleanup(trainer, file)\n \"\"\"\n if f\"{id(trainer)}.py\" in file: # if temp_file suffix in file\n os.remove(file)", "chunk_type": "function", "name": "ddp_cleanup", "file_path": "ultralytics\\ultralytics\\utils\\dist.py", "start_line": 102, "end_line": 119, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Delete temporary file if created during distributed data parallel (DDP) training.\n\nThis function checks if the provided file contains the trainer's ID in its name, indicating it was created\nas a temporary file for DDP training, and deletes it if so.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.\n file (str): Path to the file that might need to be deleted.\n\nExamples:\n >>> trainer = YOLOTrainer()\n >>> file = \"/tmp/ddp_temp_123456789.py\"\n >>> ddp_cleanup(trainer, file)", "parameters": [ "trainer", "file" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "os", "shutil", "sys", "tempfile", "USER_CONFIG_DIR", "torch_utils.TORCH_1_9", "socket", "__main__" ], "chunk_id": "function_ddp_cleanup_d6aa6b54" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_69b018a7" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_fd555cdc" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_de40055c" }, { "content": "from itertools import repeat", "chunk_type": "import", "name": "repeat", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_repeat_380fc23f" }, { "content": "from multiprocessing.pool import ThreadPool", "chunk_type": "import", "name": "ThreadPool", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ThreadPool_b6d1e36d" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", 
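The three `dist.py` helpers above are normally orchestrated by `BaseTrainer`; the sketch below shows the assumed lifecycle in isolation. `DetectionTrainer`, the override values, and `world_size = 2` are illustrative assumptions, not values taken from this file.

```python
# Hedged sketch of how generate_ddp_command() and ddp_cleanup() fit together.
# Assumptions: a 2-GPU machine and an illustrative DetectionTrainer config;
# the real orchestration lives inside ultralytics.engine.trainer.BaseTrainer.
import subprocess

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command

trainer = DetectionTrainer(overrides={"model": "yolo11n.pt", "data": "coco8.yaml", "device": "0,1"})
world_size = 2  # processes to spawn, one per GPU

cmd, temp_file = generate_ddp_command(world_size, trainer)  # writes the temp DDP file and builds the launch command
try:
    subprocess.run(cmd, check=True)  # e.g. python -m torch.distributed.run --nproc_per_node 2 --master_port <port> <temp_file>
finally:
    ddp_cleanup(trainer, str(temp_file))  # removes the temp file only if its name carries this trainer's id()
```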
"start_line": 8, "end_line": 8, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_7ae7dad5" }, { "content": "from typing import List, Tuple", "chunk_type": "import", "name": "List, Tuple", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Tuple_328d12b9" }, { "content": "from urllib import parse, request", "chunk_type": "import", "name": "parse, request", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_parse, request_36875657" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_364b3d6b" }, { "content": "from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file", "chunk_type": "import", "name": "LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file_4cd4534a" }, { "content": "GITHUB_ASSETS_REPO = \"ultralytics/assets\"", "chunk_type": "variable", "name": "GITHUB_ASSETS_REPO", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_GITHUB_ASSETS_REPO_74b554b4" }, { "content": "GITHUB_ASSETS_NAMES = frozenset(\n [f\"yolov8{k}{suffix}.pt\" for k in \"nsmlx\" for suffix in (\"\", \"-cls\", \"-seg\", \"-pose\", \"-obb\", \"-oiv7\")]\n + [f\"yolo11{k}{suffix}.pt\" for k in \"nsmlx\" for suffix in (\"\", \"-cls\", \"-seg\", \"-pose\", \"-obb\")]\n + [f\"yolo12{k}{suffix}.pt\" for k in \"nsmlx\" for suffix in (\"\",)] # detect models only currently\n + [f\"yolov5{k}{resolution}u.pt\" for k in \"nsmlx\" for resolution in (\"\", \"6\")]\n + [f\"yolov3{k}u.pt\" for k in (\"\", \"-spp\", \"-tiny\")]\n + [f\"yolov8{k}-world.pt\" for k in \"smlx\"]\n + [f\"yolov8{k}-worldv2.pt\" for k in \"smlx\"]\n + [f\"yoloe-v8{k}{suffix}.pt\" for k in \"sml\" for suffix in (\"-seg\", \"-seg-pf\")]\n + [f\"yoloe-11{k}{suffix}.pt\" for k in \"sml\" for suffix in (\"-seg\", \"-seg-pf\")]\n + [f\"yolov9{k}.pt\" for k in \"tsmce\"]\n + [f\"yolov10{k}.pt\" for k in \"nsmblx\"]\n + [f\"yolo_nas_{k}.pt\" for k in \"sml\"]\n + [f\"sam_{k}.pt\" for k in \"bl\"]\n + [f\"sam2_{k}.pt\" for k in \"blst\"]\n + [f\"sam2.1_{k}.pt\" for k in 
\"blst\"]\n + [f\"FastSAM-{k}.pt\" for k in \"sx\"]\n + [f\"rtdetr-{k}.pt\" for k in \"lx\"]\n + [\n \"mobile_sam.pt\",\n \"mobileclip_blt.ts\",\n \"yolo11n-grayscale.pt\",\n \"calibration_image_sample_data_20x128x128x3_float32.npy.zip\",\n ]\n)", "chunk_type": "variable", "name": "GITHUB_ASSETS_NAMES", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 18, "end_line": 42, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_GITHUB_ASSETS_NAMES_ba359a8b" }, { "content": "GITHUB_ASSETS_STEMS = frozenset(k.rpartition(\".\")[0] for k in GITHUB_ASSETS_NAMES)", "chunk_type": "variable", "name": "GITHUB_ASSETS_STEMS", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 43, "end_line": 43, "start_col": 0, "end_col": 82, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_GITHUB_ASSETS_STEMS_97b36524" }, { "content": "def is_url(url, check: bool = False) -> bool:\n \"\"\"\n Validate if the given string is a URL and optionally check if the URL exists online.\n\n Args:\n url (str): The string to be validated as a URL.\n check (bool, optional): If True, performs an additional check to see if the URL exists online.\n\n Returns:\n (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.\n\n Examples:\n >>> valid = is_url(\"https://www.example.com\")\n >>> valid_and_exists = is_url(\"https://www.example.com\", check=True)\n \"\"\"\n try:\n url = str(url)\n result = parse.urlparse(url)\n assert all([result.scheme, result.netloc]) # check if is url\n if check:\n with request.urlopen(url) as response:\n return response.getcode() == 200 # check if exists online\n return True\n except Exception:\n return False", "chunk_type": "function", "name": "is_url", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 46, "end_line": 70, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Validate if the given string is a URL and optionally check if the URL exists online.\n\nArgs:\n url (str): The string to be validated as a URL.\n check (bool, optional): If True, performs an additional check to see if the URL exists online.\n\nReturns:\n (bool): True for a valid URL. 
If 'check' is True, also returns True if the URL exists online.\n\nExamples:\n >>> valid = is_url(\"https://www.example.com\")\n >>> valid_and_exists = is_url(\"https://www.example.com\", check=True)", "parameters": [ "url", "check: bool" ], "return_type": "bool", "decorators": [], "complexity_score": 3, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_is_url_791f4eb3" }, { "content": "def delete_dsstore(path, files_to_delete=(\".DS_Store\", \"__MACOSX\")):\n \"\"\"\n Delete all specified system files in a directory.\n\n Args:\n path (str | Path): The directory path where the files should be deleted.\n files_to_delete (tuple): The files to be deleted.\n\n Examples:\n >>> from ultralytics.utils.downloads import delete_dsstore\n >>> delete_dsstore(\"path/to/dir\")\n\n Notes:\n \".DS_store\" files are created by the Apple operating system and contain metadata about folders and files. They\n are hidden system files and can cause issues when transferring files between different operating systems.\n \"\"\"\n for file in files_to_delete:\n matches = list(Path(path).rglob(file))\n LOGGER.info(f\"Deleting {file} files: {matches}\")\n for f in matches:\n f.unlink()", "chunk_type": "function", "name": "delete_dsstore", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 73, "end_line": 93, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Delete all specified system files in a directory.\n\nArgs:\n path (str | Path): The directory path where the files should be deleted.\n files_to_delete (tuple): The files to be deleted.\n\nExamples:\n >>> from ultralytics.utils.downloads import delete_dsstore\n >>> delete_dsstore(\"path/to/dir\")\n\nNotes:\n \".DS_store\" files are created by the Apple operating system and contain metadata about folders and files. 
They\n are hidden system files and can cause issues when transferring files between different operating systems.", "parameters": [ "path", "files_to_delete" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_delete_dsstore_cfe15d33" }, { "content": "def zip_directory(directory, compress: bool = True, exclude=(\".DS_Store\", \"__MACOSX\"), progress: bool = True) -> Path:\n \"\"\"\n Zip the contents of a directory, excluding specified files.\n\n The resulting zip file is named after the directory and placed alongside it.\n\n Args:\n directory (str | Path): The path to the directory to be zipped.\n compress (bool): Whether to compress the files while zipping.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n progress (bool, optional): Whether to display a progress bar.\n\n Returns:\n (Path): The path to the resulting zip file.\n\n Examples:\n >>> from ultralytics.utils.downloads import zip_directory\n >>> file = zip_directory(\"path/to/dir\")\n \"\"\"\n from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile\n\n delete_dsstore(directory)\n directory = Path(directory)\n if not directory.is_dir():\n raise FileNotFoundError(f\"Directory '{directory}' does not exist.\")\n\n # Zip with progress bar\n files_to_zip = [f for f in directory.rglob(\"*\") if f.is_file() and all(x not in f.name for x in exclude)]\n zip_file = directory.with_suffix(\".zip\")\n compression = ZIP_DEFLATED if compress else ZIP_STORED\n with ZipFile(zip_file, \"w\", compression) as f:\n for file in TQDM(files_to_zip, desc=f\"Zipping {directory} to {zip_file}...\", unit=\"file\", disable=not progress):\n f.write(file, file.relative_to(directory))\n\n return zip_file # return path to zip file", "chunk_type": "function", "name": "zip_directory", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 96, "end_line": 130, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Zip the contents of a directory, excluding specified files.\n\nThe resulting zip file is named after the directory and placed alongside it.\n\nArgs:\n directory (str | Path): The path to the directory to be zipped.\n compress (bool): Whether to compress the files while zipping.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n progress (bool, optional): Whether to display a progress bar.\n\nReturns:\n (Path): The path to the resulting zip file.\n\nExamples:\n >>> from ultralytics.utils.downloads import zip_directory\n >>> file = zip_directory(\"path/to/dir\")", "parameters": [ "directory", "compress: bool", "exclude", "progress: bool" ], "return_type": "Path", "decorators": [], "complexity_score": 5, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", 
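A minimal round-trip sketch for `zip_directory` together with the `unzip_file` helper defined next; the `datasets/coco8` path is a placeholder and is assumed to already exist on disk.

```python
# Round-trip sketch: archive a directory, then extract it again (placeholder paths).
from pathlib import Path

from ultralytics.utils.downloads import unzip_file, zip_directory

src = Path("datasets/coco8")                  # assumed existing directory
zip_path = zip_directory(src, compress=True)  # -> datasets/coco8.zip, written next to the source dir
out_dir = unzip_file(zip_path, path=src.parent, exist_ok=True)  # entries are stored relative to src, so they re-extract into datasets/coco8
print(zip_path, out_dir)
```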
"ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_zip_directory_d38bafc2" }, { "content": "def unzip_file(\n file,\n path=None,\n exclude=(\".DS_Store\", \"__MACOSX\"),\n exist_ok: bool = False,\n progress: bool = True,\n) -> Path:\n \"\"\"\n Unzip a *.zip file to the specified path, excluding specified files.\n\n If the zipfile does not contain a single top-level directory, the function will create a new\n directory with the same name as the zipfile (without the extension) to extract its contents.\n If a path is not provided, the function will use the parent directory of the zipfile as the default path.\n\n Args:\n file (str | Path): The path to the zipfile to be extracted.\n path (str | Path, optional): The path to extract the zipfile to.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n exist_ok (bool, optional): Whether to overwrite existing contents if they exist.\n progress (bool, optional): Whether to display a progress bar.\n\n Returns:\n (Path): The path to the directory where the zipfile was extracted.\n\n Raises:\n BadZipFile: If the provided file does not exist or is not a valid zipfile.\n\n Examples:\n >>> from ultralytics.utils.downloads import unzip_file\n >>> directory = unzip_file(\"path/to/file.zip\")\n \"\"\"\n from zipfile import BadZipFile, ZipFile, is_zipfile\n\n if not (Path(file).exists() and is_zipfile(file)):\n raise BadZipFile(f\"File '{file}' does not exist or is a bad zip file.\")\n if path is None:\n path = Path(file).parent # default path\n\n # Unzip the file contents\n with ZipFile(file) as zipObj:\n files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]\n top_level_dirs = {Path(f).parts[0] for f in files}\n\n # Decide to unzip directly or unzip into a directory\n unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith(\"/\"))\n if unzip_as_dir:\n # Zip has 1 top-level directory\n extract_path = path # i.e. ../datasets\n path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/\n else:\n # Zip has multiple files at top level\n path = extract_path = Path(path) / Path(file).stem # i.e. 
extract multiple files to ../datasets/coco8/\n\n # Check if destination directory already exists and contains files\n if path.exists() and any(path.iterdir()) and not exist_ok:\n # If it exists and is not empty, return the path without unzipping\n LOGGER.warning(f\"Skipping {file} unzip as destination directory {path} is not empty.\")\n return path\n\n for f in TQDM(files, desc=f\"Unzipping {file} to {Path(path).resolve()}...\", unit=\"file\", disable=not progress):\n # Ensure the file is within the extract_path to avoid path traversal security vulnerability\n if \"..\" in Path(f).parts:\n LOGGER.warning(f\"Potentially insecure file path: {f}, skipping extraction.\")\n continue\n zipObj.extract(f, extract_path)\n\n return path # return unzip dir", "chunk_type": "function", "name": "unzip_file", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 133, "end_line": 199, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Unzip a *.zip file to the specified path, excluding specified files.\n\nIf the zipfile does not contain a single top-level directory, the function will create a new\ndirectory with the same name as the zipfile (without the extension) to extract its contents.\nIf a path is not provided, the function will use the parent directory of the zipfile as the default path.\n\nArgs:\n file (str | Path): The path to the zipfile to be extracted.\n path (str | Path, optional): The path to extract the zipfile to.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n exist_ok (bool, optional): Whether to overwrite existing contents if they exist.\n progress (bool, optional): Whether to display a progress bar.\n\nReturns:\n (Path): The path to the directory where the zipfile was extracted.\n\nRaises:\n BadZipFile: If the provided file does not exist or is not a valid zipfile.\n\nExamples:\n >>> from ultralytics.utils.downloads import unzip_file\n >>> directory = unzip_file(\"path/to/file.zip\")", "parameters": [ "file", "path", "exclude", "exist_ok: bool", "progress: bool" ], "return_type": "Path", "decorators": [], "complexity_score": 10, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_unzip_file_941bdc4b" }, { "content": "def check_disk_space(\n url: str = \"https://ultralytics.com/assets/coco8.zip\",\n path=Path.cwd(),\n sf: float = 1.5,\n hard: bool = True,\n) -> bool:\n \"\"\"\n Check if there is sufficient disk space to download and store a file.\n\n Args:\n url (str, optional): The URL to the file.\n path (str | Path, optional): The path or drive to check the available free space on.\n sf (float, optional): Safety factor, the multiplier for the required free space.\n hard (bool, optional): Whether to throw an error or not on insufficient disk space.\n\n Returns:\n (bool): True if there is sufficient disk space, False otherwise.\n \"\"\"\n import requests # slow import\n\n try:\n r = requests.head(url) # response\n assert r.status_code < 400, f\"URL 
error for {url}: {r.status_code} {r.reason}\" # check response\n except Exception:\n return True # requests issue, default to True\n\n # Check file size\n gib = 1 << 30 # bytes per GiB\n data = int(r.headers.get(\"Content-Length\", 0)) / gib # file size (GB)\n total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes\n\n if data * sf < free:\n return True # sufficient space\n\n # Insufficient space\n text = (\n f\"Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, \"\n f\"Please free {data * sf - free:.1f} GB additional disk space and try again.\"\n )\n if hard:\n raise MemoryError(text)\n LOGGER.warning(text)\n return False", "chunk_type": "function", "name": "check_disk_space", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 202, "end_line": 244, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Check if there is sufficient disk space to download and store a file.\n\nArgs:\n url (str, optional): The URL to the file.\n path (str | Path, optional): The path or drive to check the available free space on.\n sf (float, optional): Safety factor, the multiplier for the required free space.\n hard (bool, optional): Whether to throw an error or not on insufficient disk space.\n\nReturns:\n (bool): True if there is sufficient disk space, False otherwise.", "parameters": [ "url: str", "path", "sf: float", "hard: bool" ], "return_type": "bool", "decorators": [], "complexity_score": 5, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_check_disk_space_f4244111" }, { "content": "def get_google_drive_file_info(link: str) -> Tuple[str, str]:\n \"\"\"\n Retrieve the direct download link and filename for a shareable Google Drive file link.\n\n Args:\n link (str): The shareable link of the Google Drive file.\n\n Returns:\n url (str): Direct download URL for the Google Drive file.\n filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.\n\n Examples:\n >>> from ultralytics.utils.downloads import get_google_drive_file_info\n >>> link = \"https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link\"\n >>> url, filename = get_google_drive_file_info(link)\n \"\"\"\n import requests # slow import\n\n file_id = link.split(\"/d/\")[1].split(\"/view\", 1)[0]\n drive_url = f\"https://drive.google.com/uc?export=download&id={file_id}\"\n filename = None\n\n # Start session\n with requests.Session() as session:\n response = session.get(drive_url, stream=True)\n if \"quota exceeded\" in str(response.content.lower()):\n raise ConnectionError(\n emojis(\n f\"❌ Google Drive file download quota exceeded. 
\"\n f\"Please try again later or download this file manually at {link}.\"\n )\n )\n for k, v in response.cookies.items():\n if k.startswith(\"download_warning\"):\n drive_url += f\"&confirm={v}\" # v is token\n if cd := response.headers.get(\"content-disposition\"):\n filename = re.findall('filename=\"(.+)\"', cd)[0]\n return drive_url, filename", "chunk_type": "function", "name": "get_google_drive_file_info", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 247, "end_line": 284, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Retrieve the direct download link and filename for a shareable Google Drive file link.\n\nArgs:\n link (str): The shareable link of the Google Drive file.\n\nReturns:\n url (str): Direct download URL for the Google Drive file.\n filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.\n\nExamples:\n >>> from ultralytics.utils.downloads import get_google_drive_file_info\n >>> link = \"https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link\"\n >>> url, filename = get_google_drive_file_info(link)", "parameters": [ "link: str" ], "return_type": "Tuple[str, str]", "decorators": [], "complexity_score": 5, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_get_google_drive_file_info_3e2a9616" }, { "content": "def safe_download(\n url,\n file=None,\n dir=None,\n unzip: bool = True,\n delete: bool = False,\n curl: bool = False,\n retry: int = 3,\n min_bytes: float = 1e0,\n exist_ok: bool = False,\n progress: bool = True,\n):\n \"\"\"\n Download files from a URL with options for retrying, unzipping, and deleting the downloaded file.\n\n Args:\n url (str): The URL of the file to be downloaded.\n file (str, optional): The filename of the downloaded file.\n If not provided, the file will be saved with the same name as the URL.\n dir (str | Path, optional): The directory to save the downloaded file.\n If not provided, the file will be saved in the current working directory.\n unzip (bool, optional): Whether to unzip the downloaded file.\n delete (bool, optional): Whether to delete the downloaded file after unzipping.\n curl (bool, optional): Whether to use curl command line tool for downloading.\n retry (int, optional): The number of times to retry the download in case of failure.\n min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered\n a successful download.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n progress (bool, optional): Whether to display a progress bar during the download.\n\n Returns:\n (Path | str): The path to the downloaded file or extracted directory.\n\n Examples:\n >>> from ultralytics.utils.downloads import safe_download\n >>> link = \"https://ultralytics.com/assets/bus.jpg\"\n >>> path = safe_download(link)\n \"\"\"\n gdrive = 
url.startswith(\"https://drive.google.com/\") # check if the URL is a Google Drive link\n if gdrive:\n url, file = get_google_drive_file_info(url)\n\n f = Path(dir or \".\") / (file or url2file(url)) # URL converted to filename\n if \"://\" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)\n f = Path(url) # filename\n elif not f.is_file(): # URL and file do not exist\n uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url\n \"https://github.com/ultralytics/assets/releases/download/v0.0.0/\",\n \"https://ultralytics.com/assets/\", # assets alias\n )\n desc = f\"Downloading {uri} to '{f}'\"\n LOGGER.info(f\"{desc}...\")\n f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing\n check_disk_space(url, path=f.parent)\n curl_installed = shutil.which(\"curl\")\n for i in range(retry + 1):\n try:\n if (curl or i > 0) and curl_installed: # curl download with retry, continue\n s = \"sS\" * (not progress) # silent\n r = subprocess.run([\"curl\", \"-#\", f\"-{s}L\", url, \"-o\", f, \"--retry\", \"3\", \"-C\", \"-\"]).returncode\n assert r == 0, f\"Curl return value {r}\"\n else: # urllib download\n method = \"torch\"\n if method == \"torch\":\n torch.hub.download_url_to_file(url, f, progress=progress)\n else:\n with request.urlopen(url) as response, TQDM(\n total=int(response.getheader(\"Content-Length\", 0)),\n desc=desc,\n disable=not progress,\n unit=\"B\",\n unit_scale=True,\n unit_divisor=1024,\n ) as pbar:\n with open(f, \"wb\") as f_opened:\n for data in response:\n f_opened.write(data)\n pbar.update(len(data))\n\n if f.exists():\n if f.stat().st_size > min_bytes:\n break # success\n f.unlink() # remove partial downloads\n except Exception as e:\n if i == 0 and not is_online():\n raise ConnectionError(emojis(f\"❌ Download failure for {uri}. Environment is not online.\")) from e\n elif i >= retry:\n raise ConnectionError(emojis(f\"❌ Download failure for {uri}. 
Retry limit reached.\")) from e\n LOGGER.warning(f\"Download failure, retrying {i + 1}/{retry} {uri}...\")\n\n if unzip and f.exists() and f.suffix in {\"\", \".zip\", \".tar\", \".gz\"}:\n from zipfile import is_zipfile\n\n unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place\n if is_zipfile(f):\n unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip\n elif f.suffix in {\".tar\", \".gz\"}:\n LOGGER.info(f\"Unzipping {f} to {unzip_dir}...\")\n subprocess.run([\"tar\", \"xf\" if f.suffix == \".tar\" else \"xfz\", f, \"--directory\", unzip_dir], check=True)\n if delete:\n f.unlink() # remove zip\n return unzip_dir\n return f", "chunk_type": "function", "name": "safe_download", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 287, "end_line": 389, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Download files from a URL with options for retrying, unzipping, and deleting the downloaded file.\n\nArgs:\n url (str): The URL of the file to be downloaded.\n file (str, optional): The filename of the downloaded file.\n If not provided, the file will be saved with the same name as the URL.\n dir (str | Path, optional): The directory to save the downloaded file.\n If not provided, the file will be saved in the current working directory.\n unzip (bool, optional): Whether to unzip the downloaded file.\n delete (bool, optional): Whether to delete the downloaded file after unzipping.\n curl (bool, optional): Whether to use curl command line tool for downloading.\n retry (int, optional): The number of times to retry the download in case of failure.\n min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered\n a successful download.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n progress (bool, optional): Whether to display a progress bar during the download.\n\nReturns:\n (Path | str): The path to the downloaded file or extracted directory.\n\nExamples:\n >>> from ultralytics.utils.downloads import safe_download\n >>> link = \"https://ultralytics.com/assets/bus.jpg\"\n >>> path = safe_download(link)", "parameters": [ "url", "file", "dir", "unzip: bool", "delete: bool", "curl: bool", "retry: int", "min_bytes: float", "exist_ok: bool", "progress: bool" ], "return_type": null, "decorators": [], "complexity_score": 17, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_safe_download_915d599d" }, { "content": "def get_github_assets(\n repo: str = \"ultralytics/assets\",\n version: str = \"latest\",\n retry: bool = False,\n) -> Tuple[str, List[str]]:\n \"\"\"\n Retrieve the specified version's tag and assets from a GitHub repository.\n\n If the version is not specified, the function fetches the latest release assets.\n\n Args:\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n version (str, optional): The 
release version to fetch assets from.\n retry (bool, optional): Flag to retry the request in case of a failure.\n\n Returns:\n tag (str): The release tag.\n assets (List[str]): A list of asset names.\n\n Examples:\n >>> tag, assets = get_github_assets(repo=\"ultralytics/assets\", version=\"latest\")\n \"\"\"\n import requests # slow import\n\n if version != \"latest\":\n version = f\"tags/{version}\" # i.e. tags/v6.2\n url = f\"https://api.github.com/repos/{repo}/releases/{version}\"\n r = requests.get(url) # github api\n if r.status_code != 200 and r.reason != \"rate limit exceeded\" and retry: # failed and not 403 rate limit exceeded\n r = requests.get(url) # try again\n if r.status_code != 200:\n LOGGER.warning(f\"GitHub assets check failure for {url}: {r.status_code} {r.reason}\")\n return \"\", []\n data = r.json()\n return data[\"tag_name\"], [x[\"name\"] for x in data[\"assets\"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]", "chunk_type": "function", "name": "get_github_assets", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 392, "end_line": 426, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": "Retrieve the specified version's tag and assets from a GitHub repository.\n\nIf the version is not specified, the function fetches the latest release assets.\n\nArgs:\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n version (str, optional): The release version to fetch assets from.\n retry (bool, optional): Flag to retry the request in case of a failure.\n\nReturns:\n tag (str): The release tag.\n assets (List[str]): A list of asset names.\n\nExamples:\n >>> tag, assets = get_github_assets(repo=\"ultralytics/assets\", version=\"latest\")", "parameters": [ "repo: str", "version: str", "retry: bool" ], "return_type": "Tuple[str, List[str]]", "decorators": [], "complexity_score": 5, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_get_github_assets_f59536ca" }, { "content": "def attempt_download_asset(file, repo: str = \"ultralytics/assets\", release: str = \"v8.3.0\", **kwargs) -> str:\n \"\"\"\n Attempt to download a file from GitHub release assets if it is not found locally.\n\n Args:\n file (str | Path): The filename or file path to be downloaded.\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n release (str, optional): The specific release version to be downloaded.\n **kwargs (Any): Additional keyword arguments for the download process.\n\n Returns:\n (str): The path to the downloaded file.\n\n Examples:\n >>> file_path = attempt_download_asset(\"yolo11n.pt\", repo=\"ultralytics/assets\", release=\"latest\")\n \"\"\"\n from ultralytics.utils import SETTINGS # scoped for circular import\n\n # YOLOv3/5u updates\n file = str(file)\n file = checks.check_yolov5u_filename(file)\n file = Path(file.strip().replace(\"'\", \"\"))\n if file.exists():\n return str(file)\n elif (SETTINGS[\"weights_dir\"] / 
file).exists():\n return str(SETTINGS[\"weights_dir\"] / file)\n else:\n # URL specified\n name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.\n download_url = f\"https://github.com/{repo}/releases/download\"\n if str(file).startswith((\"http:/\", \"https:/\")): # download\n url = str(file).replace(\":/\", \"://\") # Pathlib turns :// -> :/\n file = url2file(name) # parse authentication https://url.com/file.txt?auth...\n if Path(file).is_file():\n LOGGER.info(f\"Found {clean_url(url)} locally at {file}\") # file already exists\n else:\n safe_download(url=url, file=file, min_bytes=1e5, **kwargs)\n\n elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:\n safe_download(url=f\"{download_url}/{release}/{name}\", file=file, min_bytes=1e5, **kwargs)\n\n else:\n tag, assets = get_github_assets(repo, release)\n if not assets:\n tag, assets = get_github_assets(repo) # latest release\n if name in assets:\n safe_download(url=f\"{download_url}/{tag}/{name}\", file=file, min_bytes=1e5, **kwargs)\n\n return str(file)", "chunk_type": "function", "name": "attempt_download_asset", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 429, "end_line": 477, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Attempt to download a file from GitHub release assets if it is not found locally.\n\nArgs:\n file (str | Path): The filename or file path to be downloaded.\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n release (str, optional): The specific release version to be downloaded.\n **kwargs (Any): Additional keyword arguments for the download process.\n\nReturns:\n (str): The path to the downloaded file.\n\nExamples:\n >>> file_path = attempt_download_asset(\"yolo11n.pt\", repo=\"ultralytics/assets\", release=\"latest\")", "parameters": [ "file", "repo: str", "release: str" ], "return_type": "str", "decorators": [], "complexity_score": 8, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_attempt_download_asset_d74cab11" }, { "content": "def download(\n url,\n dir=Path.cwd(),\n unzip: bool = True,\n delete: bool = False,\n curl: bool = False,\n threads: int = 1,\n retry: int = 3,\n exist_ok: bool = False,\n):\n \"\"\"\n Download files from specified URLs to a given directory.\n\n Supports concurrent downloads if multiple threads are specified.\n\n Args:\n url (str | List[str]): The URL or list of URLs of the files to be downloaded.\n dir (Path, optional): The directory where the files will be saved.\n unzip (bool, optional): Flag to unzip the files after downloading.\n delete (bool, optional): Flag to delete the zip files after extraction.\n curl (bool, optional): Flag to use curl for downloading.\n threads (int, optional): Number of threads to use for concurrent downloads.\n retry (int, optional): Number of retries in case of download failure.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n\n Examples:\n 
>>> download(\"https://ultralytics.com/assets/example.zip\", dir=\"path/to/dir\", unzip=True)\n \"\"\"\n dir = Path(dir)\n dir.mkdir(parents=True, exist_ok=True) # make directory\n if threads > 1:\n with ThreadPool(threads) as pool:\n pool.map(\n lambda x: safe_download(\n url=x[0],\n dir=x[1],\n unzip=unzip,\n delete=delete,\n curl=curl,\n retry=retry,\n exist_ok=exist_ok,\n progress=threads <= 1,\n ),\n zip(url, repeat(dir)),\n )\n pool.close()\n pool.join()\n else:\n for u in [url] if isinstance(url, (str, Path)) else url:\n safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)", "chunk_type": "function", "name": "download", "file_path": "ultralytics\\ultralytics\\utils\\downloads.py", "start_line": 480, "end_line": 529, "start_col": 0, "end_col": 112, "parent_name": null, "docstring": "Download files from specified URLs to a given directory.\n\nSupports concurrent downloads if multiple threads are specified.\n\nArgs:\n url (str | List[str]): The URL or list of URLs of the files to be downloaded.\n dir (Path, optional): The directory where the files will be saved.\n unzip (bool, optional): Flag to unzip the files after downloading.\n delete (bool, optional): Flag to delete the zip files after extraction.\n curl (bool, optional): Flag to use curl for downloading.\n threads (int, optional): Number of threads to use for concurrent downloads.\n retry (int, optional): Number of retries in case of download failure.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n\nExamples:\n >>> download(\"https://ultralytics.com/assets/example.zip\", dir=\"path/to/dir\", unzip=True)", "parameters": [ "url", "dir", "unzip: bool", "delete: bool", "curl: bool", "threads: int", "retry: int", "exist_ok: bool" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "re", "shutil", "subprocess", "itertools.repeat", "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.List", "typing.Tuple", "urllib.parse", "urllib.request", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.checks", "ultralytics.utils.clean_url", "ultralytics.utils.emojis", "ultralytics.utils.is_online", "ultralytics.utils.url2file", "zipfile.ZIP_DEFLATED", "zipfile.ZIP_STORED", "zipfile.ZipFile", "zipfile.BadZipFile", "zipfile.ZipFile", "zipfile.is_zipfile", "requests", "requests", "requests", "ultralytics.utils.SETTINGS", "zipfile.is_zipfile" ], "chunk_id": "function_download_a573ab84" }, { "content": "from ultralytics.utils import emojis", "chunk_type": "import", "name": "emojis", "file_path": "ultralytics\\ultralytics\\utils\\errors.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_emojis_003e07a6" }, { "content": "class HUBModelError(Exception):\n \"\"\"\n Exception raised when a model cannot be found or retrieved from Ultralytics HUB.\n\n This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.\n The error message is processed to include emojis for better user experience.\n\n Attributes:\n message (str): The error message displayed when the exception is raised.\n\n Methods:\n __init__: Initialize the HUBModelError with a custom message.\n\n Examples:\n >>> try:\n ... # Code that might fail to find a model\n ... 
raise HUBModelError(\"Custom model not found message\")\n ... except HUBModelError as e:\n ... print(e) # Displays the emoji-enhanced error message\n \"\"\"\n\n def __init__(self, message: str = \"Model not found. Please check model URL and try again.\"):\n \"\"\"\n Initialize a HUBModelError exception.\n\n This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB.\n The message is processed to include emojis for better user experience.\n\n Args:\n message (str, optional): The error message to display when the exception is raised.\n\n Examples:\n >>> try:\n ... raise HUBModelError(\"Custom model error message\")\n ... except HUBModelError as e:\n ... print(e)\n \"\"\"\n super().__init__(emojis(message))", "chunk_type": "class", "name": "HUBModelError", "file_path": "ultralytics\\ultralytics\\utils\\errors.py", "start_line": 6, "end_line": 43, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Exception raised when a model cannot be found or retrieved from Ultralytics HUB.\n\nThis custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.\nThe error message is processed to include emojis for better user experience.\n\nAttributes:\n message (str): The error message displayed when the exception is raised.\n\nMethods:\n __init__: Initialize the HUBModelError with a custom message.\n\nExamples:\n >>> try:\n ... # Code that might fail to find a model\n ... raise HUBModelError(\"Custom model not found message\")\n ... except HUBModelError as e:\n ... print(e) # Displays the emoji-enhanced error message", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ultralytics.utils.emojis", "Exception" ], "chunk_id": "class_HUBModelError_4875e1d8" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_83c7175c" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_c24a6087" }, { "content": "from typing import Dict, List, Optional, Tuple, Union", "chunk_type": "import", "name": "Dict, List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Dict, List, Optional, Tuple, Union_31e6da6e" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_4e281efa" }, { "content": "from ultralytics.utils import IS_JETSON, LOGGER", "chunk_type": "import", "name": 
"IS_JETSON, LOGGER", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IS_JETSON, LOGGER_ae878683" }, { "content": "def export_onnx(\n torch_model: torch.nn.Module,\n im: torch.Tensor,\n onnx_file: str,\n opset: int = 14,\n input_names: List[str] = [\"images\"],\n output_names: List[str] = [\"output0\"],\n dynamic: Union[bool, Dict] = False,\n) -> None:\n \"\"\"\n Export a PyTorch model to ONNX format.\n\n Args:\n torch_model (torch.nn.Module): The PyTorch model to export.\n im (torch.Tensor): Example input tensor for the model.\n onnx_file (str): Path to save the exported ONNX file.\n opset (int): ONNX opset version to use for export.\n input_names (List[str]): List of input tensor names.\n output_names (List[str]): List of output tensor names.\n dynamic (bool | Dict, optional): Whether to enable dynamic axes.\n\n Notes:\n Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.\n \"\"\"\n torch.onnx.export(\n torch_model,\n im,\n onnx_file,\n verbose=False,\n opset_version=opset,\n do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False\n input_names=input_names,\n output_names=output_names,\n dynamic_axes=dynamic or None,\n )", "chunk_type": "function", "name": "export_onnx", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 12, "end_line": 46, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Export a PyTorch model to ONNX format.\n\nArgs:\n torch_model (torch.nn.Module): The PyTorch model to export.\n im (torch.Tensor): Example input tensor for the model.\n onnx_file (str): Path to save the exported ONNX file.\n opset (int): ONNX opset version to use for export.\n input_names (List[str]): List of input tensor names.\n output_names (List[str]): List of output tensor names.\n dynamic (bool | Dict, optional): Whether to enable dynamic axes.\n\nNotes:\n Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.", "parameters": [ "torch_model: torch.nn.Module", "im: torch.Tensor", "onnx_file: str", "opset: int", "input_names: List[str]", "output_names: List[str]", "dynamic: Union[bool, Dict]" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "json", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LOGGER", "tensorrt" ], "chunk_id": "function_export_onnx_918c7e90" }, { "content": "def export_engine(\n onnx_file: str,\n engine_file: Optional[str] = None,\n workspace: Optional[int] = None,\n half: bool = False,\n int8: bool = False,\n dynamic: bool = False,\n shape: Tuple[int, int, int, int] = (1, 3, 640, 640),\n dla: Optional[int] = None,\n dataset=None,\n metadata: Optional[Dict] = None,\n verbose: bool = False,\n prefix: str = \"\",\n) -> None:\n \"\"\"\n Export a YOLO model to TensorRT engine format.\n\n Args:\n onnx_file (str): Path to the ONNX file to be converted.\n engine_file (str, optional): Path to save the generated TensorRT engine file.\n workspace (int, optional): Workspace size in GB for TensorRT.\n half (bool, optional): Enable FP16 precision.\n int8 (bool, optional): Enable INT8 precision.\n dynamic (bool, optional): Enable dynamic input shapes.\n shape 
(Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).\n dla (int, optional): DLA core to use (Jetson devices only).\n dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.\n metadata (Dict, optional): Metadata to include in the engine file.\n verbose (bool, optional): Enable verbose logging.\n prefix (str, optional): Prefix for log messages.\n\n Raises:\n ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.\n RuntimeError: If the ONNX file cannot be parsed.\n\n Notes:\n TensorRT version compatibility is handled for workspace size and engine building.\n INT8 calibration requires a dataset and generates a calibration cache.\n Metadata is serialized and written to the engine file if provided.\n \"\"\"\n import tensorrt as trt # noqa\n\n engine_file = engine_file or Path(onnx_file).with_suffix(\".engine\")\n\n logger = trt.Logger(trt.Logger.INFO)\n if verbose:\n logger.min_severity = trt.Logger.Severity.VERBOSE\n\n # Engine builder\n builder = trt.Builder(logger)\n config = builder.create_builder_config()\n workspace = int((workspace or 0) * (1 << 30))\n is_trt10 = int(trt.__version__.split(\".\", 1)[0]) >= 10 # is TensorRT >= 10\n if is_trt10 and workspace > 0:\n config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)\n elif workspace > 0: # TensorRT versions 7, 8\n config.max_workspace_size = workspace\n flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)\n network = builder.create_network(flag)\n half = builder.platform_has_fast_fp16 and half\n int8 = builder.platform_has_fast_int8 and int8\n\n # Optionally switch to DLA if enabled\n if dla is not None:\n if not IS_JETSON:\n raise ValueError(\"DLA is only available on NVIDIA Jetson devices\")\n LOGGER.info(f\"{prefix} enabling DLA on core {dla}...\")\n if not half and not int8:\n raise ValueError(\n \"DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. 
Please enable one of them and try again.\"\n )\n config.default_device_type = trt.DeviceType.DLA\n config.DLA_core = int(dla)\n config.set_flag(trt.BuilderFlag.GPU_FALLBACK)\n\n # Read ONNX file\n parser = trt.OnnxParser(network, logger)\n if not parser.parse_from_file(onnx_file):\n raise RuntimeError(f\"failed to load ONNX file: {onnx_file}\")\n\n # Network inputs\n inputs = [network.get_input(i) for i in range(network.num_inputs)]\n outputs = [network.get_output(i) for i in range(network.num_outputs)]\n for inp in inputs:\n LOGGER.info(f'{prefix} input \"{inp.name}\" with shape{inp.shape} {inp.dtype}')\n for out in outputs:\n LOGGER.info(f'{prefix} output \"{out.name}\" with shape{out.shape} {out.dtype}')\n\n if dynamic:\n profile = builder.create_optimization_profile()\n min_shape = (1, shape[1], 32, 32) # minimum input shape\n max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape\n for inp in inputs:\n profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)\n config.add_optimization_profile(profile)\n if int8:\n config.set_calibration_profile(profile)\n\n LOGGER.info(f\"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}\")\n if int8:\n config.set_flag(trt.BuilderFlag.INT8)\n config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED\n\n class EngineCalibrator(trt.IInt8Calibrator):\n \"\"\"\n Custom INT8 calibrator for TensorRT engine optimization.\n\n This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration\n using a dataset. It handles batch generation, caching, and calibration algorithm selection.\n\n Attributes:\n dataset: Dataset for calibration.\n data_iter: Iterator over the calibration dataset.\n algo (trt.CalibrationAlgoType): Calibration algorithm type.\n batch (int): Batch size for calibration.\n cache (Path): Path to save the calibration cache.\n\n Methods:\n get_algorithm: Get the calibration algorithm to use.\n get_batch_size: Get the batch size to use for calibration.\n get_batch: Get the next batch to use for calibration.\n read_calibration_cache: Use existing cache instead of calibrating again.\n write_calibration_cache: Write calibration cache to disk.\n \"\"\"\n\n def __init__(\n self,\n dataset, # ultralytics.data.build.InfiniteDataLoader\n cache: str = \"\",\n ) -> None:\n \"\"\"Initialize the INT8 calibrator with dataset and cache path.\"\"\"\n trt.IInt8Calibrator.__init__(self)\n self.dataset = dataset\n self.data_iter = iter(dataset)\n self.algo = (\n trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2\n if dla is not None\n else trt.CalibrationAlgoType.MINMAX_CALIBRATION\n )\n self.batch = dataset.batch_size\n self.cache = Path(cache)\n\n def get_algorithm(self) -> trt.CalibrationAlgoType:\n \"\"\"Get the calibration algorithm to use.\"\"\"\n return self.algo\n\n def get_batch_size(self) -> int:\n \"\"\"Get the batch size to use for calibration.\"\"\"\n return self.batch or 1\n\n def get_batch(self, names) -> Optional[List[int]]:\n \"\"\"Get the next batch to use for calibration, as a list of device memory pointers.\"\"\"\n try:\n im0s = next(self.data_iter)[\"img\"] / 255.0\n im0s = im0s.to(\"cuda\") if im0s.device.type == \"cpu\" else im0s\n return [int(im0s.data_ptr())]\n except StopIteration:\n # Return None to signal to TensorRT there is no calibration data remaining\n return None\n\n def read_calibration_cache(self) -> Optional[bytes]:\n \"\"\"Use existing cache 
instead of calibrating again, otherwise, implicitly return None.\"\"\"\n if self.cache.exists() and self.cache.suffix == \".cache\":\n return self.cache.read_bytes()\n\n def write_calibration_cache(self, cache: bytes) -> None:\n \"\"\"Write calibration cache to disk.\"\"\"\n _ = self.cache.write_bytes(cache)\n\n # Load dataset w/ builder (for batching) and calibrate\n config.int8_calibrator = EngineCalibrator(\n dataset=dataset,\n cache=str(Path(onnx_file).with_suffix(\".cache\")),\n )\n\n elif half:\n config.set_flag(trt.BuilderFlag.FP16)\n\n # Write file\n build = builder.build_serialized_network if is_trt10 else builder.build_engine\n with build(network, config) as engine, open(engine_file, \"wb\") as t:\n # Metadata\n if metadata is not None:\n meta = json.dumps(metadata)\n t.write(len(meta).to_bytes(4, byteorder=\"little\", signed=True))\n t.write(meta.encode())\n # Model\n t.write(engine if is_trt10 else engine.serialize())", "chunk_type": "function", "name": "export_engine", "file_path": "ultralytics\\ultralytics\\utils\\export.py", "start_line": 49, "end_line": 236, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "Export a YOLO model to TensorRT engine format.\n\nArgs:\n onnx_file (str): Path to the ONNX file to be converted.\n engine_file (str, optional): Path to save the generated TensorRT engine file.\n workspace (int, optional): Workspace size in GB for TensorRT.\n half (bool, optional): Enable FP16 precision.\n int8 (bool, optional): Enable INT8 precision.\n dynamic (bool, optional): Enable dynamic input shapes.\n shape (Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).\n dla (int, optional): DLA core to use (Jetson devices only).\n dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.\n metadata (Dict, optional): Metadata to include in the engine file.\n verbose (bool, optional): Enable verbose logging.\n prefix (str, optional): Prefix for log messages.\n\nRaises:\n ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.\n RuntimeError: If the ONNX file cannot be parsed.\n\nNotes:\n TensorRT version compatibility is handled for workspace size and engine building.\n INT8 calibration requires a dataset and generates a calibration cache.\n Metadata is serialized and written to the engine file if provided.", "parameters": [ "onnx_file: str", "engine_file: Optional[str]", "workspace: Optional[int]", "half: bool", "int8: bool", "dynamic: bool", "shape: Tuple[int, int, int, int]", "dla: Optional[int]", "dataset", "metadata: Optional[Dict]", "verbose: bool", "prefix: str" ], "return_type": "None", "decorators": [], "complexity_score": 21, "dependencies": [ "json", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "ultralytics.utils.IS_JETSON", "ultralytics.utils.LOGGER", "tensorrt" ], "chunk_id": "function_export_engine_5e8b025f" }, { "content": "import contextlib", "chunk_type": "import", "name": "contextlib", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextlib_bacf7ad7" }, { "content": "import glob", "chunk_type": "import", "name": "glob", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, 
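The export_engine chunk above covers the full ONNX-to-TensorRT build path. A minimal usage sketch, assuming a CUDA machine with TensorRT installed and an already-exported ONNX model; the file names and metadata below are illustrative, not from the source:

```python
# Hedged usage sketch for export_engine (paths and metadata are hypothetical).
from ultralytics.utils.export import export_engine

export_engine(
    onnx_file="yolo11n.onnx",      # existing ONNX model exported beforehand
    engine_file="yolo11n.engine",  # defaults to the ONNX path with a .engine suffix if omitted
    workspace=4,                   # builder workspace in GB
    half=True,                     # FP16 engine where the platform supports fast FP16
    dynamic=False,
    shape=(1, 3, 640, 640),        # (batch, channels, height, width)
    metadata={"task": "detect"},   # serialized and prepended to the engine file
)
```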
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_glob_b6b5d75e" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_778704ca" }, { "content": "import shutil", "chunk_type": "import", "name": "shutil", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_shutil_3896adf2" }, { "content": "import tempfile", "chunk_type": "import", "name": "tempfile", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_tempfile_5965034a" }, { "content": "from contextlib import contextmanager", "chunk_type": "import", "name": "contextmanager", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextmanager_91e65499" }, { "content": "from datetime import datetime", "chunk_type": "import", "name": "datetime", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_datetime_ea6c4397" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_613eb98b" }, { "content": "from typing import Union", "chunk_type": "import", "name": "Union", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Union_19336d67" }, { "content": "class WorkingDirectory(contextlib.ContextDecorator):\n \"\"\"\n A context manager and decorator for temporarily changing the working directory.\n\n This class allows for the temporary change of the working directory using a context manager or decorator.\n It ensures that the original working directory is restored after the context or decorated function completes.\n\n Attributes:\n dir (Path | str): The new directory to switch to.\n cwd (Path): The original current working directory before the switch.\n\n Methods:\n __enter__: Changes the current directory to the specified directory.\n __exit__: 
Restores the original working directory on context exit.\n\n Examples:\n Using as a context manager:\n >>> with WorkingDirectory('/path/to/new/dir'):\n >>> # Perform operations in the new directory\n >>> pass\n\n Using as a decorator:\n >>> @WorkingDirectory('/path/to/new/dir')\n >>> def some_function():\n >>> # Perform operations in the new directory\n >>> pass\n \"\"\"\n\n def __init__(self, new_dir: Union[str, Path]):\n \"\"\"Initialize the WorkingDirectory context manager with the target directory.\"\"\"\n self.dir = new_dir # new dir\n self.cwd = Path.cwd().resolve() # current dir\n\n def __enter__(self):\n \"\"\"Change the current working directory to the specified directory upon entering the context.\"\"\"\n os.chdir(self.dir)\n\n def __exit__(self, exc_type, exc_val, exc_tb): # noqa\n \"\"\"Restore the original working directory when exiting the context.\"\"\"\n os.chdir(self.cwd)", "chunk_type": "class", "name": "WorkingDirectory", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 14, "end_line": 53, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "A context manager and decorator for temporarily changing the working directory.\n\nThis class allows for the temporary change of the working directory using a context manager or decorator.\nIt ensures that the original working directory is restored after the context or decorated function completes.\n\nAttributes:\n dir (Path | str): The new directory to switch to.\n cwd (Path): The original current working directory before the switch.\n\nMethods:\n __enter__: Changes the current directory to the specified directory.\n __exit__: Restores the original working directory on context exit.\n\nExamples:\n Using as a context manager:\n >>> with WorkingDirectory('/path/to/new/dir'):\n >>> # Perform operations in the new directory\n >>> pass\n\n Using as a decorator:\n >>> @WorkingDirectory('/path/to/new/dir')\n >>> def some_function():\n >>> # Perform operations in the new directory\n >>> pass", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names", "contextlib.ContextDecorator" ], "chunk_id": "class_WorkingDirectory_98c3f7ce" }, { "content": "def spaces_in_path(path: Union[str, Path]):\n \"\"\"\n Context manager to handle paths with spaces in their names.\n\n If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes\n the context code block, then copies the file/directory back to its original location.\n\n Args:\n path (str | Path): The original path that may contain spaces.\n\n Yields:\n (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the\n original path.\n\n Examples:\n >>> with spaces_in_path('/path/with spaces') as new_path:\n >>> # Your code here\n >>> pass\n \"\"\"\n # If path has spaces, replace them with underscores\n if \" \" in str(path):\n string = isinstance(path, str) # input type\n path = Path(path)\n\n # Create a temporary directory and construct the new path\n with tempfile.TemporaryDirectory() as tmp_dir:\n tmp_path = Path(tmp_dir) / path.name.replace(\" \", \"_\")\n\n # Copy file/directory\n if path.is_dir():\n shutil.copytree(path, tmp_path)\n elif path.is_file():\n tmp_path.parent.mkdir(parents=True, exist_ok=True)\n shutil.copy2(path, 
tmp_path)\n\n try:\n # Yield the temporary path\n yield str(tmp_path) if string else tmp_path\n\n finally:\n # Copy file/directory back\n if tmp_path.is_dir():\n shutil.copytree(tmp_path, path, dirs_exist_ok=True)\n elif tmp_path.is_file():\n shutil.copy2(tmp_path, path) # Copy back the file\n\n else:\n # If there are no spaces, just yield the original path\n yield path", "chunk_type": "function", "name": "spaces_in_path", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 57, "end_line": 105, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Context manager to handle paths with spaces in their names.\n\nIf a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes\nthe context code block, then copies the file/directory back to its original location.\n\nArgs:\n path (str | Path): The original path that may contain spaces.\n\nYields:\n (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the\n original path.\n\nExamples:\n >>> with spaces_in_path('/path/with spaces') as new_path:\n >>> # Your code here\n >>> pass", "parameters": [ "path: Union[str, Path]" ], "return_type": null, "decorators": [ "contextmanager" ], "complexity_score": 6, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_spaces_in_path_fac3d854" }, { "content": "def increment_path(path: Union[str, Path], exist_ok: bool = False, sep: str = \"\", mkdir: bool = False) -> Path:\n \"\"\"\n Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.\n\n If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to\n the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the\n number will be appended directly to the end of the path.\n\n Args:\n path (str | Path): Path to increment.\n exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.\n sep (str, optional): Separator to use between the path and the incrementation number.\n mkdir (bool, optional): Create a directory if it does not exist.\n\n Returns:\n (Path): Incremented path.\n\n Examples:\n Increment a directory path:\n >>> from pathlib import Path\n >>> path = Path(\"runs/exp\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp2\n\n Increment a file path:\n >>> path = Path(\"runs/exp/results.txt\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp/results2.txt\n \"\"\"\n path = Path(path) # os-agnostic\n if path.exists() and not exist_ok:\n path, suffix = (path.with_suffix(\"\"), path.suffix) if path.is_file() else (path, \"\")\n\n # Method 1\n for n in range(2, 9999):\n p = f\"{path}{sep}{n}{suffix}\" # increment path\n if not os.path.exists(p):\n break\n path = Path(p)\n\n if mkdir:\n path.mkdir(parents=True, exist_ok=True) # make directory\n\n return path", "chunk_type": "function", "name": "increment_path", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 108, "end_line": 153, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... 
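WorkingDirectory and spaces_in_path above are both context managers; a small combined sketch (the directory name is hypothetical) showing a space-containing path shadowed by a space-free temporary copy while work happens inside it:

```python
# Illustrative sketch, not part of the source file.
from pathlib import Path
from ultralytics.utils.files import WorkingDirectory, spaces_in_path

Path("datasets/my data").mkdir(parents=True, exist_ok=True)  # hypothetical directory containing a space

with spaces_in_path("datasets/my data") as safe_path:  # temporary space-free copy while inside the block
    with WorkingDirectory(safe_path):                  # cwd switched to the copy, restored on exit
        print(Path.cwd())                              # ends in .../my_data

print(Path.cwd())                                      # original working directory again
```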
etc.\n\nIf the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to\nthe end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the\nnumber will be appended directly to the end of the path.\n\nArgs:\n path (str | Path): Path to increment.\n exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.\n sep (str, optional): Separator to use between the path and the incrementation number.\n mkdir (bool, optional): Create a directory if it does not exist.\n\nReturns:\n (Path): Incremented path.\n\nExamples:\n Increment a directory path:\n >>> from pathlib import Path\n >>> path = Path(\"runs/exp\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp2\n\n Increment a file path:\n >>> path = Path(\"runs/exp/results.txt\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp/results2.txt", "parameters": [ "path: Union[str, Path]", "exist_ok: bool", "sep: str", "mkdir: bool" ], "return_type": "Path", "decorators": [], "complexity_score": 5, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_increment_path_f13e1833" }, { "content": "def file_age(path: Union[str, Path] = __file__) -> int:\n \"\"\"Return days since the last modification of the specified file.\"\"\"\n dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta\n return dt.days # + dt.seconds / 86400 # fractional days", "chunk_type": "function", "name": "file_age", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 156, "end_line": 159, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Return days since the last modification of the specified file.", "parameters": [ "path: Union[str, Path]" ], "return_type": "int", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_file_age_ba03809c" }, { "content": "def file_date(path: Union[str, Path] = __file__) -> str:\n \"\"\"Return the file modification date in 'YYYY-M-D' format.\"\"\"\n t = datetime.fromtimestamp(Path(path).stat().st_mtime)\n return f\"{t.year}-{t.month}-{t.day}\"", "chunk_type": "function", "name": "file_date", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 162, "end_line": 165, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Return the file modification date in 'YYYY-M-D' format.", "parameters": [ "path: Union[str, Path]" ], "return_type": "str", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_file_date_0e5fa0e2" }, { "content": "def file_size(path: Union[str, Path]) -> float:\n \"\"\"Return the size of a file or directory in megabytes (MB).\"\"\"\n if isinstance(path, (str, Path)):\n mb = 1 << 20 # bytes to MiB (1024 ** 2)\n path = Path(path)\n if path.is_file():\n return path.stat().st_size / mb\n elif path.is_dir():\n return sum(f.stat().st_size for f in 
path.glob(\"**/*\") if f.is_file()) / mb\n return 0.0", "chunk_type": "function", "name": "file_size", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 168, "end_line": 177, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Return the size of a file or directory in megabytes (MB).", "parameters": [ "path: Union[str, Path]" ], "return_type": "float", "decorators": [], "complexity_score": 5, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_file_size_04b9f1b2" }, { "content": "def get_latest_run(search_dir: str = \".\") -> str:\n \"\"\"Return the path to the most recent 'last.pt' file in the specified directory for resuming training.\"\"\"\n last_list = glob.glob(f\"{search_dir}/**/last*.pt\", recursive=True)\n return max(last_list, key=os.path.getctime) if last_list else \"\"", "chunk_type": "function", "name": "get_latest_run", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 180, "end_line": 183, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": "Return the path to the most recent 'last.pt' file in the specified directory for resuming training.", "parameters": [ "search_dir: str" ], "return_type": "str", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_get_latest_run_dcf5fd05" }, { "content": "def update_models(model_names: tuple = (\"yolo11n.pt\",), source_dir: Path = Path(\".\"), update_names: bool = False):\n \"\"\"\n Update and re-save specified YOLO models in an 'updated_models' subdirectory.\n\n Args:\n model_names (tuple, optional): Model filenames to update.\n source_dir (Path, optional): Directory containing models and target subdirectory.\n update_names (bool, optional): Update model names from a data YAML.\n\n Examples:\n Update specified YOLO models and save them in 'updated_models' subdirectory:\n >>> from ultralytics.utils.files import update_models\n >>> model_names = (\"yolo11n.pt\", \"yolov8s.pt\")\n >>> update_models(model_names, source_dir=Path(\"/models\"), update_names=True)\n \"\"\"\n from ultralytics import YOLO\n from ultralytics.nn.autobackend import default_class_names\n\n target_dir = source_dir / \"updated_models\"\n target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists\n\n for model_name in model_names:\n model_path = source_dir / model_name\n print(f\"Loading model from {model_path}\")\n\n # Load model\n model = YOLO(model_path)\n model.half()\n if update_names: # update model names from a dataset YAML\n model.model.names = default_class_names(\"coco8.yaml\")\n\n # Define new save path\n save_path = target_dir / model_name\n\n # Save model using model.save()\n print(f\"Re-saving {model_name} model to {save_path}\")\n model.save(save_path)", "chunk_type": "function", "name": "update_models", "file_path": "ultralytics\\ultralytics\\utils\\files.py", "start_line": 186, "end_line": 222, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": "Update and re-save specified YOLO models in an 'updated_models' subdirectory.\n\nArgs:\n model_names (tuple, optional): Model filenames to update.\n source_dir (Path, optional): Directory 
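A quick sketch of the small path/file helpers above (increment_path, file_age, file_date, file_size, get_latest_run); all paths are illustrative:

```python
# Illustrative sketch of the files.py helpers, not part of the source.
from pathlib import Path
from ultralytics.utils.files import file_age, file_date, file_size, get_latest_run, increment_path

save_dir = increment_path(Path("runs/exp"), mkdir=True)   # runs/exp, then runs/exp2, runs/exp3, ...
print(file_age(__file__), file_date(__file__))            # days since modification and a 'YYYY-M-D' string
print(f"{file_size(save_dir):.3f} MB")                    # directory size (0.000 for a fresh empty dir)
print(get_latest_run("runs") or "no last*.pt checkpoint found")  # most recent last*.pt for resuming, if any
```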
containing models and target subdirectory.\n update_names (bool, optional): Update model names from a data YAML.\n\nExamples:\n Update specified YOLO models and save them in 'updated_models' subdirectory:\n >>> from ultralytics.utils.files import update_models\n >>> model_names = (\"yolo11n.pt\", \"yolov8s.pt\")\n >>> update_models(model_names, source_dir=Path(\"/models\"), update_names=True)", "parameters": [ "model_names: tuple", "source_dir: Path", "update_names: bool" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "glob", "os", "shutil", "tempfile", "contextlib.contextmanager", "datetime.datetime", "pathlib.Path", "typing.Union", "ultralytics.YOLO", "ultralytics.nn.autobackend.default_class_names" ], "chunk_id": "function_update_models_b08d3edc" }, { "content": "from collections import abc", "chunk_type": "import", "name": "abc", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_abc_1a55c307" }, { "content": "from itertools import repeat", "chunk_type": "import", "name": "repeat", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_repeat_818933e2" }, { "content": "from numbers import Number", "chunk_type": "import", "name": "Number", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Number_b9d16c87" }, { "content": "from typing import List, Union", "chunk_type": "import", "name": "List, Union", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Union_2df48fed" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_a2efdae5" }, { "content": "from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh", "chunk_type": "import", "name": "ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 100, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh_9b94f05f" }, { "content": "def _ntuple(n):\n \"\"\"Create a function that converts input to n-tuple by repeating singleton 
values.\"\"\"\n\n def parse(x):\n \"\"\"Parse input to return n-tuple by repeating singleton values n times.\"\"\"\n return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))\n\n return parse", "chunk_type": "function", "name": "_ntuple", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 13, "end_line": 20, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Create a function that converts input to n-tuple by repeating singleton values.", "parameters": [ "n" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc", "itertools.repeat", "numbers.Number", "typing.List", "typing.Union", "numpy", "ops.ltwh2xywh", "ops.ltwh2xyxy", "ops.resample_segments", "ops.xywh2ltwh", "ops.xywh2xyxy", "ops.xyxy2ltwh", "ops.xyxy2xywh" ], "chunk_id": "function__ntuple_833ec773" }, { "content": "to_2tuple = _ntuple(2)", "chunk_type": "variable", "name": "to_2tuple", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_to_2tuple_0c3837fe" }, { "content": "to_4tuple = _ntuple(4)", "chunk_type": "variable", "name": "to_4tuple", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_to_4tuple_a9f230e9" }, { "content": "_formats = [\"xyxy\", \"xywh\", \"ltwh\"]", "chunk_type": "variable", "name": "_formats", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 29, "end_line": 29, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable__formats_736251a6" }, { "content": "__all__ = (\"Bboxes\", \"Instances\") # tuple or list", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 31, "end_line": 31, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___bf54126f" }, { "content": "class Bboxes:\n \"\"\"\n A class for handling bounding boxes in multiple formats.\n\n The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format\n conversion, scaling, and area calculation. 
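The _ntuple factory above simply repeats singleton values; a behaviour sketch for the derived to_2tuple/to_4tuple helpers:

```python
# Behaviour sketch for the _ntuple-derived helpers (not part of the source file).
from ultralytics.utils.instance import to_2tuple, to_4tuple

print(to_2tuple(640))        # (640, 640)          -> singletons are repeated n times
print(to_4tuple(0.5))        # (0.5, 0.5, 0.5, 0.5)
print(to_4tuple((1, 2, 3)))  # (1, 2, 3)           -> iterables pass through unchanged
```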
Bounding box data should be provided as numpy arrays.\n\n Attributes:\n bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).\n format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').\n\n Methods:\n convert: Convert bounding box format from one type to another.\n areas: Calculate the area of bounding boxes.\n mul: Multiply bounding box coordinates by scale factor(s).\n add: Add offset to bounding box coordinates.\n concatenate: Concatenate multiple Bboxes objects.\n\n Examples:\n Create bounding boxes in YOLO format\n >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format=\"xywh\")\n >>> bboxes.convert(\"xyxy\")\n >>> print(bboxes.areas())\n\n Notes:\n This class does not handle normalization or denormalization of bounding boxes.\n \"\"\"\n\n def __init__(self, bboxes: np.ndarray, format: str = \"xyxy\") -> None:\n \"\"\"\n Initialize the Bboxes class with bounding box data in a specified format.\n\n Args:\n bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).\n format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'.\n \"\"\"\n assert format in _formats, f\"Invalid bounding box format: {format}, format must be one of {_formats}\"\n bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes\n assert bboxes.ndim == 2\n assert bboxes.shape[1] == 4\n self.bboxes = bboxes\n self.format = format\n\n def convert(self, format: str) -> None:\n \"\"\"\n Convert bounding box format from one type to another.\n\n Args:\n format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.\n \"\"\"\n assert format in _formats, f\"Invalid bounding box format: {format}, format must be one of {_formats}\"\n if self.format == format:\n return\n elif self.format == \"xyxy\":\n func = xyxy2xywh if format == \"xywh\" else xyxy2ltwh\n elif self.format == \"xywh\":\n func = xywh2xyxy if format == \"xyxy\" else xywh2ltwh\n else:\n func = ltwh2xyxy if format == \"xyxy\" else ltwh2xywh\n self.bboxes = func(self.bboxes)\n self.format = format\n\n def areas(self) -> np.ndarray:\n \"\"\"Calculate the area of bounding boxes.\"\"\"\n return (\n (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy\n if self.format == \"xyxy\"\n else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh\n )\n\n def mul(self, scale: Union[int, tuple, list]) -> None:\n \"\"\"\n Multiply bounding box coordinates by scale factor(s).\n\n Args:\n scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to\n all coordinates.\n \"\"\"\n if isinstance(scale, Number):\n scale = to_4tuple(scale)\n assert isinstance(scale, (tuple, list))\n assert len(scale) == 4\n self.bboxes[:, 0] *= scale[0]\n self.bboxes[:, 1] *= scale[1]\n self.bboxes[:, 2] *= scale[2]\n self.bboxes[:, 3] *= scale[3]\n\n def add(self, offset: Union[int, tuple, list]) -> None:\n \"\"\"\n Add offset to bounding box coordinates.\n\n Args:\n offset (int | tuple | list): Offset(s) for four coordinates. 
If int, the same offset is applied to\n all coordinates.\n \"\"\"\n if isinstance(offset, Number):\n offset = to_4tuple(offset)\n assert isinstance(offset, (tuple, list))\n assert len(offset) == 4\n self.bboxes[:, 0] += offset[0]\n self.bboxes[:, 1] += offset[1]\n self.bboxes[:, 2] += offset[2]\n self.bboxes[:, 3] += offset[3]\n\n def __len__(self) -> int:\n \"\"\"Return the number of bounding boxes.\"\"\"\n return len(self.bboxes)\n\n @classmethod\n def concatenate(cls, boxes_list: List[\"Bboxes\"], axis: int = 0) -> \"Bboxes\":\n \"\"\"\n Concatenate a list of Bboxes objects into a single Bboxes object.\n\n Args:\n boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.\n axis (int, optional): The axis along which to concatenate the bounding boxes.\n\n Returns:\n (Bboxes): A new Bboxes object containing the concatenated bounding boxes.\n\n Notes:\n The input should be a list or tuple of Bboxes objects.\n \"\"\"\n assert isinstance(boxes_list, (list, tuple))\n if not boxes_list:\n return cls(np.empty(0))\n assert all(isinstance(box, Bboxes) for box in boxes_list)\n\n if len(boxes_list) == 1:\n return boxes_list[0]\n return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))\n\n def __getitem__(self, index: Union[int, np.ndarray, slice]) -> \"Bboxes\":\n \"\"\"\n Retrieve a specific bounding box or a set of bounding boxes using indexing.\n\n Args:\n index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.\n\n Returns:\n (Bboxes): A new Bboxes object containing the selected bounding boxes.\n\n Notes:\n When using boolean indexing, make sure to provide a boolean array with the same length as the number of\n bounding boxes.\n \"\"\"\n if isinstance(index, int):\n return Bboxes(self.bboxes[index].reshape(1, -1))\n b = self.bboxes[index]\n assert b.ndim == 2, f\"Indexing on Bboxes with {index} failed to return a matrix!\"\n return Bboxes(b)", "chunk_type": "class", "name": "Bboxes", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 34, "end_line": 184, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "A class for handling bounding boxes in multiple formats.\n\nThe class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format\nconversion, scaling, and area calculation. 
Bounding box data should be provided as numpy arrays.\n\nAttributes:\n bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).\n format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').\n\nMethods:\n convert: Convert bounding box format from one type to another.\n areas: Calculate the area of bounding boxes.\n mul: Multiply bounding box coordinates by scale factor(s).\n add: Add offset to bounding box coordinates.\n concatenate: Concatenate multiple Bboxes objects.\n\nExamples:\n Create bounding boxes in YOLO format\n >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format=\"xywh\")\n >>> bboxes.convert(\"xyxy\")\n >>> print(bboxes.areas())\n\nNotes:\n This class does not handle normalization or denormalization of bounding boxes.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.abc", "itertools.repeat", "numbers.Number", "typing.List", "typing.Union", "numpy", "ops.ltwh2xywh", "ops.ltwh2xyxy", "ops.resample_segments", "ops.xywh2ltwh", "ops.xywh2xyxy", "ops.xyxy2ltwh", "ops.xyxy2xywh" ], "chunk_id": "class_Bboxes_8691b2b0" }, { "content": "class Instances:\n \"\"\"\n Container for bounding boxes, segments, and keypoints of detected objects in an image.\n\n This class provides a unified interface for handling different types of object annotations including bounding\n boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,\n and format conversion.\n\n Attributes:\n _bboxes (Bboxes): Internal object for handling bounding box operations.\n keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).\n normalized (bool): Flag indicating whether the bounding box coordinates are normalized.\n segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.\n\n Methods:\n convert_bbox: Convert bounding box format.\n scale: Scale coordinates by given factors.\n denormalize: Convert normalized coordinates to absolute coordinates.\n normalize: Convert absolute coordinates to normalized coordinates.\n add_padding: Add padding to coordinates.\n flipud: Flip coordinates vertically.\n fliplr: Flip coordinates horizontally.\n clip: Clip coordinates to stay within image boundaries.\n remove_zero_area_boxes: Remove boxes with zero area.\n update: Update instance variables.\n concatenate: Concatenate multiple Instances objects.\n\n Examples:\n Create instances with bounding boxes and segments\n >>> instances = Instances(\n ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),\n ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],\n ... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),\n ... 
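A short usage sketch for the Bboxes class above, with made-up coordinates, showing in-place format conversion, areas, and scaling:

```python
# Usage sketch for Bboxes (values are illustrative).
import numpy as np
from ultralytics.utils.instance import Bboxes

boxes = Bboxes(np.array([[100.0, 50.0, 150.0, 100.0]]), format="xywh")  # (cx, cy, w, h)
boxes.convert("xyxy")            # in-place format conversion
print(boxes.bboxes)              # [[ 25.   0. 175. 100.]]
print(boxes.areas())             # [15000.]
boxes.mul((0.5, 0.5, 0.5, 0.5))  # scale all four coordinates in place
print(len(boxes))                # 1
```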
)\n \"\"\"\n\n def __init__(\n self,\n bboxes: np.ndarray,\n segments: np.ndarray = None,\n keypoints: np.ndarray = None,\n bbox_format: str = \"xywh\",\n normalized: bool = True,\n ) -> None:\n \"\"\"\n Initialize the Instances object with bounding boxes, segments, and keypoints.\n\n Args:\n bboxes (np.ndarray): Bounding boxes with shape (N, 4).\n segments (np.ndarray, optional): Segmentation masks.\n keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).\n bbox_format (str): Format of bboxes.\n normalized (bool): Whether the coordinates are normalized.\n \"\"\"\n self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)\n self.keypoints = keypoints\n self.normalized = normalized\n self.segments = segments\n\n def convert_bbox(self, format: str) -> None:\n \"\"\"\n Convert bounding box format.\n\n Args:\n format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.\n \"\"\"\n self._bboxes.convert(format=format)\n\n @property\n def bbox_areas(self) -> np.ndarray:\n \"\"\"Calculate the area of bounding boxes.\"\"\"\n return self._bboxes.areas()\n\n def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False):\n \"\"\"\n Scale coordinates by given factors.\n\n Args:\n scale_w (float): Scale factor for width.\n scale_h (float): Scale factor for height.\n bbox_only (bool, optional): Whether to scale only bounding boxes.\n \"\"\"\n self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))\n if bbox_only:\n return\n self.segments[..., 0] *= scale_w\n self.segments[..., 1] *= scale_h\n if self.keypoints is not None:\n self.keypoints[..., 0] *= scale_w\n self.keypoints[..., 1] *= scale_h\n\n def denormalize(self, w: int, h: int) -> None:\n \"\"\"\n Convert normalized coordinates to absolute coordinates.\n\n Args:\n w (int): Image width.\n h (int): Image height.\n \"\"\"\n if not self.normalized:\n return\n self._bboxes.mul(scale=(w, h, w, h))\n self.segments[..., 0] *= w\n self.segments[..., 1] *= h\n if self.keypoints is not None:\n self.keypoints[..., 0] *= w\n self.keypoints[..., 1] *= h\n self.normalized = False\n\n def normalize(self, w: int, h: int) -> None:\n \"\"\"\n Convert absolute coordinates to normalized coordinates.\n\n Args:\n w (int): Image width.\n h (int): Image height.\n \"\"\"\n if self.normalized:\n return\n self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))\n self.segments[..., 0] /= w\n self.segments[..., 1] /= h\n if self.keypoints is not None:\n self.keypoints[..., 0] /= w\n self.keypoints[..., 1] /= h\n self.normalized = True\n\n def add_padding(self, padw: int, padh: int) -> None:\n \"\"\"\n Add padding to coordinates.\n\n Args:\n padw (int): Padding width.\n padh (int): Padding height.\n \"\"\"\n assert not self.normalized, \"you should add padding with absolute coordinates.\"\n self._bboxes.add(offset=(padw, padh, padw, padh))\n self.segments[..., 0] += padw\n self.segments[..., 1] += padh\n if self.keypoints is not None:\n self.keypoints[..., 0] += padw\n self.keypoints[..., 1] += padh\n\n def __getitem__(self, index: Union[int, np.ndarray, slice]) -> \"Instances\":\n \"\"\"\n Retrieve a specific instance or a set of instances using indexing.\n\n Args:\n index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.\n\n Returns:\n (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.\n\n Notes:\n When using boolean indexing, make sure to provide a boolean array with the same length as the number of\n 
instances.\n \"\"\"\n segments = self.segments[index] if len(self.segments) else self.segments\n keypoints = self.keypoints[index] if self.keypoints is not None else None\n bboxes = self.bboxes[index]\n bbox_format = self._bboxes.format\n return Instances(\n bboxes=bboxes,\n segments=segments,\n keypoints=keypoints,\n bbox_format=bbox_format,\n normalized=self.normalized,\n )\n\n def flipud(self, h: int) -> None:\n \"\"\"\n Flip coordinates vertically.\n\n Args:\n h (int): Image height.\n \"\"\"\n if self._bboxes.format == \"xyxy\":\n y1 = self.bboxes[:, 1].copy()\n y2 = self.bboxes[:, 3].copy()\n self.bboxes[:, 1] = h - y2\n self.bboxes[:, 3] = h - y1\n else:\n self.bboxes[:, 1] = h - self.bboxes[:, 1]\n self.segments[..., 1] = h - self.segments[..., 1]\n if self.keypoints is not None:\n self.keypoints[..., 1] = h - self.keypoints[..., 1]\n\n def fliplr(self, w: int) -> None:\n \"\"\"\n Flip coordinates horizontally.\n\n Args:\n w (int): Image width.\n \"\"\"\n if self._bboxes.format == \"xyxy\":\n x1 = self.bboxes[:, 0].copy()\n x2 = self.bboxes[:, 2].copy()\n self.bboxes[:, 0] = w - x2\n self.bboxes[:, 2] = w - x1\n else:\n self.bboxes[:, 0] = w - self.bboxes[:, 0]\n self.segments[..., 0] = w - self.segments[..., 0]\n if self.keypoints is not None:\n self.keypoints[..., 0] = w - self.keypoints[..., 0]\n\n def clip(self, w: int, h: int) -> None:\n \"\"\"\n Clip coordinates to stay within image boundaries.\n\n Args:\n w (int): Image width.\n h (int): Image height.\n \"\"\"\n ori_format = self._bboxes.format\n self.convert_bbox(format=\"xyxy\")\n self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)\n self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)\n if ori_format != \"xyxy\":\n self.convert_bbox(format=ori_format)\n self.segments[..., 0] = self.segments[..., 0].clip(0, w)\n self.segments[..., 1] = self.segments[..., 1].clip(0, h)\n if self.keypoints is not None:\n # Set out of bounds visibility to zero\n self.keypoints[..., 2][\n (self.keypoints[..., 0] < 0)\n | (self.keypoints[..., 0] > w)\n | (self.keypoints[..., 1] < 0)\n | (self.keypoints[..., 1] > h)\n ] = 0.0\n self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)\n self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)\n\n def remove_zero_area_boxes(self) -> np.ndarray:\n \"\"\"\n Remove zero-area boxes, i.e. 
after clipping some boxes may have zero width or height.\n\n Returns:\n (np.ndarray): Boolean array indicating which boxes were kept.\n \"\"\"\n good = self.bbox_areas > 0\n if not all(good):\n self._bboxes = self._bboxes[good]\n if len(self.segments):\n self.segments = self.segments[good]\n if self.keypoints is not None:\n self.keypoints = self.keypoints[good]\n return good\n\n def update(self, bboxes: np.ndarray, segments: np.ndarray = None, keypoints: np.ndarray = None):\n \"\"\"\n Update instance variables.\n\n Args:\n bboxes (np.ndarray): New bounding boxes.\n segments (np.ndarray, optional): New segments.\n keypoints (np.ndarray, optional): New keypoints.\n \"\"\"\n self._bboxes = Bboxes(bboxes, format=self._bboxes.format)\n if segments is not None:\n self.segments = segments\n if keypoints is not None:\n self.keypoints = keypoints\n\n def __len__(self) -> int:\n \"\"\"Return the number of instances.\"\"\"\n return len(self.bboxes)\n\n @classmethod\n def concatenate(cls, instances_list: List[\"Instances\"], axis=0) -> \"Instances\":\n \"\"\"\n Concatenate a list of Instances objects into a single Instances object.\n\n Args:\n instances_list (List[Instances]): A list of Instances objects to concatenate.\n axis (int, optional): The axis along which the arrays will be concatenated.\n\n Returns:\n (Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints\n if present.\n\n Notes:\n The `Instances` objects in the list should have the same properties, such as the format of the bounding\n boxes, whether keypoints are present, and if the coordinates are normalized.\n \"\"\"\n assert isinstance(instances_list, (list, tuple))\n if not instances_list:\n return cls(np.empty(0))\n assert all(isinstance(instance, Instances) for instance in instances_list)\n\n if len(instances_list) == 1:\n return instances_list[0]\n\n use_keypoint = instances_list[0].keypoints is not None\n bbox_format = instances_list[0]._bboxes.format\n normalized = instances_list[0].normalized\n\n cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)\n seg_len = [b.segments.shape[1] for b in instances_list]\n if len(frozenset(seg_len)) > 1: # resample segments if there's different length\n max_len = max(seg_len)\n cat_segments = np.concatenate(\n [\n resample_segments(list(b.segments), max_len)\n if len(b.segments)\n else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments\n for b in instances_list\n ],\n axis=axis,\n )\n else:\n cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)\n cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None\n return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)\n\n @property\n def bboxes(self) -> np.ndarray:\n \"\"\"Return bounding boxes.\"\"\"\n return self._bboxes.bboxes", "chunk_type": "class", "name": "Instances", "file_path": "ultralytics\\ultralytics\\utils\\instance.py", "start_line": 187, "end_line": 504, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Container for bounding boxes, segments, and keypoints of detected objects in an image.\n\nThis class provides a unified interface for handling different types of object annotations including bounding\nboxes, segmentation masks, and keypoints. 
It supports various operations like scaling, normalization, clipping,\nand format conversion.\n\nAttributes:\n _bboxes (Bboxes): Internal object for handling bounding box operations.\n keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).\n normalized (bool): Flag indicating whether the bounding box coordinates are normalized.\n segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.\n\nMethods:\n convert_bbox: Convert bounding box format.\n scale: Scale coordinates by given factors.\n denormalize: Convert normalized coordinates to absolute coordinates.\n normalize: Convert absolute coordinates to normalized coordinates.\n add_padding: Add padding to coordinates.\n flipud: Flip coordinates vertically.\n fliplr: Flip coordinates horizontally.\n clip: Clip coordinates to stay within image boundaries.\n remove_zero_area_boxes: Remove boxes with zero area.\n update: Update instance variables.\n concatenate: Concatenate multiple Instances objects.\n\nExamples:\n Create instances with bounding boxes and segments\n >>> instances = Instances(\n ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),\n ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],\n ... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),\n ... )", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.abc", "itertools.repeat", "numbers.Number", "typing.List", "typing.Union", "numpy", "ops.ltwh2xywh", "ops.ltwh2xyxy", "ops.resample_segments", "ops.xywh2ltwh", "ops.xywh2xyxy", "ops.xyxy2ltwh", "ops.xyxy2xywh" ], "chunk_id": "class_Instances_76a2c85b" }, { "content": "from typing import Any, Dict, List, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Tuple", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple_0099e5e8" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_496c93cf" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_b28e90cb" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_9922e2b8" }, { "content": "from ultralytics.utils.metrics import OKS_SIGMA", "chunk_type": "import", "name": "OKS_SIGMA", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 9, "end_line": 9, "start_col": 0, 
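A usage sketch for the Instances container above with synthetic boxes and segments (no keypoints), walking through denormalize, convert_bbox, clip, and remove_zero_area_boxes:

```python
# Synthetic-data sketch for Instances, not taken from the source.
import numpy as np
from ultralytics.utils.instance import Instances

inst = Instances(
    bboxes=np.array([[0.25, 0.25, 0.2, 0.2], [0.75, 0.75, 0.1, 0.1]], dtype=np.float32),
    segments=np.zeros((2, 10, 2), dtype=np.float32),  # resampled segments, shape (N, M, 2)
    bbox_format="xywh",
    normalized=True,
)
inst.denormalize(w=640, h=480)        # scale to absolute pixel coordinates
inst.convert_bbox("xyxy")
inst.clip(w=640, h=480)               # keep boxes and segments inside the image
keep = inst.remove_zero_area_boxes()  # boolean mask of surviving boxes
print(len(inst), keep)                # 2 [ True  True]
```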
"end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OKS_SIGMA_beb7ea8c" }, { "content": "from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh", "chunk_type": "import", "name": "crop_mask, xywh2xyxy, xyxy2xywh", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_crop_mask, xywh2xyxy, xyxy2xywh_383a0b84" }, { "content": "from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors", "chunk_type": "import", "name": "RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 117, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors_1722eb43" }, { "content": "from ultralytics.utils.torch_utils import autocast", "chunk_type": "import", "name": "autocast", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_autocast_4a645288" }, { "content": "from .metrics import bbox_iou, probiou", "chunk_type": "import", "name": "bbox_iou, probiou", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bbox_iou, probiou_3b940619" }, { "content": "from .tal import bbox2dist", "chunk_type": "import", "name": "bbox2dist", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bbox2dist_e462b29f" }, { "content": "class VarifocalLoss(nn.Module):\n \"\"\"\n Varifocal loss by Zhang et al.\n\n Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on\n hard-to-classify examples and balancing positive/negative samples.\n\n Attributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (float): The balancing factor used to address class imbalance.\n\n References:\n https://arxiv.org/abs/2008.13367\n \"\"\"\n\n def __init__(self, gamma: float = 2.0, alpha: float = 0.75):\n \"\"\"Initialize the VarifocalLoss class with focusing and balancing parameters.\"\"\"\n super().__init__()\n self.gamma = gamma\n self.alpha = alpha\n\n def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:\n \"\"\"Compute varifocal loss between predictions and ground truth.\"\"\"\n weight = self.alpha * 
pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label\n with autocast(enabled=False):\n loss = (\n (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction=\"none\") * weight)\n .mean(1)\n .sum()\n )\n return loss", "chunk_type": "class", "name": "VarifocalLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 18, "end_line": 48, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Varifocal loss by Zhang et al.\n\nImplements the Varifocal Loss function for addressing class imbalance in object detection by focusing on\nhard-to-classify examples and balancing positive/negative samples.\n\nAttributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (float): The balancing factor used to address class imbalance.\n\nReferences:\n https://arxiv.org/abs/2008.13367", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "nn.Module" ], "chunk_id": "class_VarifocalLoss_0b82ffe4" }, { "content": "class FocalLoss(nn.Module):\n \"\"\"\n Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).\n\n Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing\n on hard negatives during training.\n\n Attributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (torch.Tensor): The balancing factor used to address class imbalance.\n \"\"\"\n\n def __init__(self, gamma: float = 1.5, alpha: float = 0.25):\n \"\"\"Initialize FocalLoss class with focusing and balancing parameters.\"\"\"\n super().__init__()\n self.gamma = gamma\n self.alpha = torch.tensor(alpha)\n\n def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate focal loss with modulating factors for class imbalance.\"\"\"\n loss = F.binary_cross_entropy_with_logits(pred, label, reduction=\"none\")\n # p_t = torch.exp(-loss)\n # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability\n\n # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py\n pred_prob = pred.sigmoid() # prob from logits\n p_t = label * pred_prob + (1 - label) * (1 - pred_prob)\n modulating_factor = (1.0 - p_t) ** self.gamma\n loss *= modulating_factor\n if (self.alpha > 0).any():\n self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)\n alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)\n loss *= alpha_factor\n return loss.mean(1).sum()", "chunk_type": "class", "name": "FocalLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 51, "end_line": 84, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).\n\nImplements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing\non hard negatives during training.\n\nAttributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (torch.Tensor): The balancing factor used to address class imbalance.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "nn.Module" ], "chunk_id": "class_FocalLoss_070521f6" }, { "content": "class DFLoss(nn.Module):\n \"\"\"Criterion class for computing Distribution Focal Loss (DFL).\"\"\"\n\n def __init__(self, reg_max: int = 16) -> None:\n \"\"\"Initialize the DFL module with regularization maximum.\"\"\"\n super().__init__()\n self.reg_max = reg_max\n\n def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n \"\"\"Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391.\"\"\"\n target = target.clamp_(0, self.reg_max - 1 - 0.01)\n tl = target.long() # target left\n tr = tl + 1 # target right\n wl = tr - target # weight left\n wr = 1 - wl # weight right\n return (\n F.cross_entropy(pred_dist, tl.view(-1), reduction=\"none\").view(tl.shape) * wl\n + F.cross_entropy(pred_dist, tr.view(-1), reduction=\"none\").view(tl.shape) * wr\n ).mean(-1, keepdim=True)", "chunk_type": "class", "name": "DFLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 87, "end_line": 105, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Criterion class for computing Distribution Focal Loss (DFL).", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "nn.Module" ], "chunk_id": "class_DFLoss_77fc59e8" }, { "content": "class BboxLoss(nn.Module):\n \"\"\"Criterion class for computing training losses for bounding boxes.\"\"\"\n\n def __init__(self, reg_max: int = 16):\n \"\"\"Initialize the BboxLoss module with regularization maximum and DFL settings.\"\"\"\n super().__init__()\n self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None\n\n def forward(\n self,\n pred_dist: torch.Tensor,\n pred_bboxes: torch.Tensor,\n anchor_points: torch.Tensor,\n target_bboxes: torch.Tensor,\n target_scores: torch.Tensor,\n target_scores_sum: torch.Tensor,\n fg_mask: torch.Tensor,\n ) -> Tuple[torch.Tensor, 
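A toy sketch exercising the VarifocalLoss, FocalLoss, and DFLoss modules above on random tensors; the shapes and the 8400-anchor count are illustrative only:

```python
# Toy sketch of the classification/distribution losses (random tensors, illustrative shapes).
import torch
from ultralytics.utils.loss import DFLoss, FocalLoss, VarifocalLoss

pred = torch.randn(2, 8400, 80)              # raw class logits (batch, anchors, classes)
label = torch.randint(0, 2, pred.shape).float()
gt_score = torch.rand_like(pred) * label     # IoU-aware targets, zero for negatives

print(VarifocalLoss(gamma=2.0, alpha=0.75)(pred, gt_score, label))
print(FocalLoss(gamma=1.5, alpha=0.25)(pred, label))

dfl = DFLoss(reg_max=16)
pred_dist = torch.randn(2 * 8400 * 4, 16)    # per-side distribution logits
target = torch.rand(2 * 8400, 4) * 15        # continuous LTRB targets within [0, reg_max - 1)
print(dfl(pred_dist, target))
```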
torch.Tensor]:\n \"\"\"Compute IoU and DFL losses for bounding boxes.\"\"\"\n weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)\n iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)\n loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum\n\n # DFL loss\n if self.dfl_loss:\n target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)\n loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight\n loss_dfl = loss_dfl.sum() / target_scores_sum\n else:\n loss_dfl = torch.tensor(0.0).to(pred_dist.device)\n\n return loss_iou, loss_dfl", "chunk_type": "class", "name": "BboxLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 108, "end_line": 139, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Criterion class for computing training losses for bounding boxes.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "nn.Module" ], "chunk_id": "class_BboxLoss_1de7f202" }, { "content": "class RotatedBboxLoss(BboxLoss):\n \"\"\"Criterion class for computing training losses for rotated bounding boxes.\"\"\"\n\n def __init__(self, reg_max: int):\n \"\"\"Initialize the RotatedBboxLoss module with regularization maximum and DFL settings.\"\"\"\n super().__init__(reg_max)\n\n def forward(\n self,\n pred_dist: torch.Tensor,\n pred_bboxes: torch.Tensor,\n anchor_points: torch.Tensor,\n target_bboxes: torch.Tensor,\n target_scores: torch.Tensor,\n target_scores_sum: torch.Tensor,\n fg_mask: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Compute IoU and DFL losses for rotated bounding boxes.\"\"\"\n weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)\n iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])\n loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum\n\n # DFL loss\n if self.dfl_loss:\n target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)\n loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight\n loss_dfl = loss_dfl.sum() / target_scores_sum\n else:\n loss_dfl = torch.tensor(0.0).to(pred_dist.device)\n\n return loss_iou, loss_dfl", "chunk_type": "class", "name": "RotatedBboxLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 142, "end_line": 172, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Criterion class for computing training losses for rotated bounding boxes.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", 
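BboxLoss normalizes the IoU term by the summed assigner scores rather than by the number of foreground anchors. A sketch of that weighting with hypothetical assigner outputs; the iou tensor here is a stand-in for the CIoU values that bbox_iou would return for the matched pairs:

import torch

# Hypothetical assigner outputs: batch of 1 image, 6 anchors, 80 classes
target_scores = torch.rand(1, 6, 80)
fg_mask = torch.tensor([[True, True, False, True, False, False]])
iou = torch.rand(int(fg_mask.sum()), 1)            # stand-in for per-match CIoU values

target_scores_sum = max(target_scores.sum(), 1)    # avoid division by zero early in training
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
print(loss_iou)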
"ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "BboxLoss" ], "chunk_id": "class_RotatedBboxLoss_0b2451d8" }, { "content": "class KeypointLoss(nn.Module):\n \"\"\"Criterion class for computing keypoint losses.\"\"\"\n\n def __init__(self, sigmas: torch.Tensor) -> None:\n \"\"\"Initialize the KeypointLoss class with keypoint sigmas.\"\"\"\n super().__init__()\n self.sigmas = sigmas\n\n def forward(\n self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor\n ) -> torch.Tensor:\n \"\"\"Calculate keypoint loss factor and Euclidean distance loss for keypoints.\"\"\"\n d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)\n kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)\n # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula\n e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval\n return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()", "chunk_type": "class", "name": "KeypointLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 175, "end_line": 191, "start_col": 0, "end_col": 86, "parent_name": null, "docstring": "Criterion class for computing keypoint losses.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "nn.Module" ], "chunk_id": "class_KeypointLoss_a5a78e11" }, { "content": "class v8DetectionLoss:\n \"\"\"Criterion class for computing training losses for YOLOv8 object detection.\"\"\"\n\n def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled\n \"\"\"Initialize v8DetectionLoss with model parameters and task-aligned assignment settings.\"\"\"\n device = next(model.parameters()).device # get model device\n h = model.args # hyperparameters\n\n m = model.model[-1] # Detect() module\n self.bce = nn.BCEWithLogitsLoss(reduction=\"none\")\n self.hyp = h\n self.stride = m.stride # model strides\n self.nc = m.nc # number of classes\n self.no = m.nc + m.reg_max * 4\n self.reg_max = m.reg_max\n self.device = device\n\n self.use_dfl = m.reg_max > 1\n\n self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)\n self.bbox_loss = BboxLoss(m.reg_max).to(device)\n self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)\n\n def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:\n \"\"\"Preprocess targets by converting to tensor format and scaling coordinates.\"\"\"\n nl, ne = targets.shape\n if nl == 0:\n out = torch.zeros(batch_size, 0, ne - 1, device=self.device)\n else:\n i = targets[:, 0] # image index\n _, counts = i.unique(return_counts=True)\n counts = counts.to(dtype=torch.int32)\n out = 
torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)\n for j in range(batch_size):\n matches = i == j\n if n := matches.sum():\n out[j, :n] = targets[matches, 1:]\n out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))\n return out\n\n def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode predicted object bounding box coordinates from anchor points and distribution.\"\"\"\n if self.use_dfl:\n b, a, c = pred_dist.shape # batch, anchors, channels\n pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))\n # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))\n # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)\n return dist2bbox(pred_dist, anchor_points, xywh=False)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the sum of the loss for box, cls and dfl multiplied by batch size.\"\"\"\n loss = torch.zeros(3, device=self.device) # box, cls, dfl\n feats = preds[1] if isinstance(preds, tuple) else preds\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n batch_size = pred_scores.shape[0]\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # Targets\n targets = torch.cat((batch[\"batch_idx\"].view(-1, 1), batch[\"cls\"].view(-1, 1), batch[\"bboxes\"]), 1)\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)\n # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)\n # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2\n\n _, target_bboxes, target_scores, fg_mask, _ = self.assigner(\n # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,\n pred_scores.detach().sigmoid(),\n (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n # Bbox loss\n if fg_mask.sum():\n target_bboxes /= stride_tensor\n loss[0], loss[2] = self.bbox_loss(\n pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask\n )\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.cls # cls gain\n loss[2] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)", "chunk_type": "class", "name": "v8DetectionLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 194, "end_line": 297, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Criterion class for computing training losses for 
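bbox_decode in v8DetectionLoss turns the DFL logits into box sides by taking the expectation over the reg_max bins and then converting (left, top, right, bottom) distances to corner boxes around each anchor, as dist2bbox does with xywh=False. A sketch of those two steps with the corner conversion written inline and hypothetical shapes:

import torch

reg_max, b, a = 16, 2, 4                        # hypothetical batch size and anchor count
pred_dist = torch.randn(b, a, 4 * reg_max)      # raw DFL logits
anchor_points = torch.rand(a, 2) * 20           # grid-space anchor centres
proj = torch.arange(reg_max, dtype=torch.float)

# Expected value of each side's distance under the predicted per-bin distribution
dist = pred_dist.view(b, a, 4, reg_max).softmax(3).matmul(proj)   # (b, a, 4) = (l, t, r, b)

# Inline equivalent of dist2bbox(..., xywh=False): anchor centre minus/plus decoded distances
lt, rb = dist.chunk(2, -1)
pred_bboxes = torch.cat((anchor_points - lt, anchor_points + rb), -1)   # xyxy in grid units
print(pred_bboxes.shape)   # torch.Size([2, 4, 4])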
YOLOv8 object detection.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist" ], "chunk_id": "class_v8DetectionLoss_0f5de2a7" }, { "content": "class v8SegmentationLoss(v8DetectionLoss):\n \"\"\"Criterion class for computing training losses for YOLOv8 segmentation.\"\"\"\n\n def __init__(self, model): # model must be de-paralleled\n \"\"\"Initialize the v8SegmentationLoss class with model parameters and mask overlap setting.\"\"\"\n super().__init__(model)\n self.overlap = model.args.overlap_mask\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate and return the combined loss for detection and segmentation.\"\"\"\n loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl\n feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]\n batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n # B, grids, ..\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n pred_masks = pred_masks.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # Targets\n try:\n batch_idx = batch[\"batch_idx\"].view(-1, 1)\n targets = torch.cat((batch_idx, batch[\"cls\"].view(-1, 1), batch[\"bboxes\"]), 1)\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n except RuntimeError as e:\n raise TypeError(\n \"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\\n\"\n \"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, \"\n \"i.e. 
'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\\nVerify your dataset is a \"\n \"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' \"\n \"as an example.\\nSee https://docs.ultralytics.com/datasets/segment/ for help.\"\n ) from e\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)\n\n _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(\n pred_scores.detach().sigmoid(),\n (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n if fg_mask.sum():\n # Bbox loss\n loss[0], loss[3] = self.bbox_loss(\n pred_distri,\n pred_bboxes,\n anchor_points,\n target_bboxes / stride_tensor,\n target_scores,\n target_scores_sum,\n fg_mask,\n )\n # Masks loss\n masks = batch[\"masks\"].to(self.device).float()\n if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample\n masks = F.interpolate(masks[None], (mask_h, mask_w), mode=\"nearest\")[0]\n\n loss[1] = self.calculate_segmentation_loss(\n fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap\n )\n\n # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove\n else:\n loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.box # seg gain\n loss[2] *= self.hyp.cls # cls gain\n loss[3] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)\n\n @staticmethod\n def single_mask_loss(\n gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor\n ) -> torch.Tensor:\n \"\"\"\n Compute the instance segmentation loss for a single image.\n\n Args:\n gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.\n pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).\n proto (torch.Tensor): Prototype masks of shape (32, H, W).\n xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).\n area (torch.Tensor): Area of each ground truth bounding box of shape (N,).\n\n Returns:\n (torch.Tensor): The calculated mask loss for a single image.\n\n Notes:\n The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the\n predicted masks from the prototype masks and predicted mask coefficients.\n \"\"\"\n pred_mask = torch.einsum(\"in,nhw->ihw\", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)\n loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction=\"none\")\n return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()\n\n def calculate_segmentation_loss(\n self,\n fg_mask: torch.Tensor,\n masks: torch.Tensor,\n target_gt_idx: torch.Tensor,\n target_bboxes: torch.Tensor,\n batch_idx: torch.Tensor,\n proto: torch.Tensor,\n pred_masks: torch.Tensor,\n imgsz: torch.Tensor,\n overlap: bool,\n ) -> torch.Tensor:\n \"\"\"\n Calculate the loss for instance segmentation.\n\n Args:\n fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.\n masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is 
False, otherwise (BS, ?, H, W).\n target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).\n target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).\n batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).\n proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).\n pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).\n imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).\n overlap (bool): Whether the masks in `masks` tensor overlap.\n\n Returns:\n (torch.Tensor): The calculated loss for instance segmentation.\n\n Notes:\n The batch loss can be computed for improved speed at higher memory usage.\n For example, pred_mask can be computed as follows:\n pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)\n \"\"\"\n _, _, mask_h, mask_w = proto.shape\n loss = 0\n\n # Normalize to 0-1\n target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]\n\n # Areas of target bboxes\n marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)\n\n # Normalize to mask size\n mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)\n\n for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):\n fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i\n if fg_mask_i.any():\n mask_idx = target_gt_idx_i[fg_mask_i]\n if overlap:\n gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)\n gt_mask = gt_mask.float()\n else:\n gt_mask = masks[batch_idx.view(-1) == i][mask_idx]\n\n loss += self.single_mask_loss(\n gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]\n )\n\n # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove\n else:\n loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss\n\n return loss / fg_mask.sum()", "chunk_type": "class", "name": "v8SegmentationLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 300, "end_line": 480, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Criterion class for computing training losses for YOLOv8 segmentation.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "v8DetectionLoss" ], "chunk_id": "class_v8SegmentationLoss_3732945c" }, { "content": "class v8PoseLoss(v8DetectionLoss):\n \"\"\"Criterion class for computing training losses for YOLOv8 pose estimation.\"\"\"\n\n def __init__(self, model): # model must be de-paralleled\n \"\"\"Initialize v8PoseLoss with model parameters and keypoint-specific loss functions.\"\"\"\n super().__init__(model)\n self.kpt_shape = model.model[-1].kpt_shape\n self.bce_pose = nn.BCEWithLogitsLoss()\n is_pose = self.kpt_shape == [17, 3]\n nkpt = self.kpt_shape[0] 
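single_mask_loss above assembles each predicted mask as a linear combination of the prototypes via torch.einsum('in,nhw->ihw', ...), applies per-pixel BCE, restricts the loss to the ground-truth box and normalizes by box area. A sketch of that pipeline; the box cropping here is an inline stand-in for ultralytics.utils.ops.crop_mask and all sizes are hypothetical:

import torch
import torch.nn.functional as F

n, nm, mh, mw = 3, 32, 160, 160                    # objects, prototypes, mask height/width
proto = torch.randn(nm, mh, mw)                    # prototype masks for one image
coef = torch.randn(n, nm)                          # predicted coefficients of matched anchors
gt_mask = (torch.rand(n, mh, mw) > 0.5).float()    # ground-truth masks at proto resolution
xyxy = torch.tensor([[10.0, 10.0, 80.0, 90.0],
                     [40.0, 20.0, 150.0, 140.0],
                     [0.0, 0.0, 159.0, 159.0]])    # boxes in mask-pixel coordinates
area = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1]) / (mh * mw)   # ~normalised areas

pred_mask = torch.einsum("in,nhw->ihw", coef, proto)   # (n, 32) @ (32, h, w) -> (n, h, w)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")

# Stand-in for crop_mask: zero the per-pixel loss outside each ground-truth box
ys = torch.arange(mh).view(1, mh, 1)
xs = torch.arange(mw).view(1, 1, mw)
inside = (xs >= xyxy[:, 0, None, None]) & (xs < xyxy[:, 2, None, None]) & \
         (ys >= xyxy[:, 1, None, None]) & (ys < xyxy[:, 3, None, None])
loss = loss * inside

print((loss.mean(dim=(1, 2)) / area).sum())        # per-object mean BCE, normalised by box area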
# number of keypoints\n sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt\n self.keypoint_loss = KeypointLoss(sigmas=sigmas)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the total loss and detach it for pose estimation.\"\"\"\n loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility\n feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n # B, grids, ..\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # Targets\n batch_size = pred_scores.shape[0]\n batch_idx = batch[\"batch_idx\"].view(-1, 1)\n targets = torch.cat((batch_idx, batch[\"cls\"].view(-1, 1), batch[\"bboxes\"]), 1)\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)\n pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)\n\n _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(\n pred_scores.detach().sigmoid(),\n (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n # Bbox loss\n if fg_mask.sum():\n target_bboxes /= stride_tensor\n loss[0], loss[4] = self.bbox_loss(\n pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask\n )\n keypoints = batch[\"keypoints\"].to(self.device).float().clone()\n keypoints[..., 0] *= imgsz[1]\n keypoints[..., 1] *= imgsz[0]\n\n loss[1], loss[2] = self.calculate_keypoints_loss(\n fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts\n )\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.pose # pose gain\n loss[2] *= self.hyp.kobj # kobj gain\n loss[3] *= self.hyp.cls # cls gain\n loss[4] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)\n\n @staticmethod\n def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode predicted keypoints to image coordinates.\"\"\"\n y = pred_kpts.clone()\n y[..., :2] *= 2.0\n y[..., 0] += anchor_points[:, [0]] - 0.5\n y[..., 1] += anchor_points[:, [1]] - 0.5\n return y\n\n def calculate_keypoints_loss(\n self,\n masks: torch.Tensor,\n target_gt_idx: torch.Tensor,\n keypoints: torch.Tensor,\n batch_idx: torch.Tensor,\n stride_tensor: torch.Tensor,\n target_bboxes: torch.Tensor,\n pred_kpts: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Calculate the 
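kpts_decode in v8PoseLoss maps the raw keypoint regressions onto the anchor grid by doubling the predicted offsets and re-centering them on each anchor point. A short sketch with hypothetical anchors and predictions:

import torch

# Hypothetical: batch of 1, 4 anchors, 17 keypoints with (x, y, visibility) each
anchor_points = torch.rand(4, 2) * 80              # grid-space anchor centres
pred_kpts = torch.randn(1, 4, 17, 3)               # raw keypoint regression output

y = pred_kpts.clone()
y[..., :2] *= 2.0                                  # widen the regression range
y[..., 0] += anchor_points[:, [0]] - 0.5           # shift x offsets onto the anchor grid
y[..., 1] += anchor_points[:, [1]] - 0.5           # shift y offsets onto the anchor grid
print(y.shape)                                     # torch.Size([1, 4, 17, 3])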
keypoints loss for the model.\n\n This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is\n based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is\n a binary classification loss that classifies whether a keypoint is present or not.\n\n Args:\n masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).\n target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).\n keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).\n batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).\n stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).\n target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).\n pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).\n\n Returns:\n kpts_loss (torch.Tensor): The keypoints loss.\n kpts_obj_loss (torch.Tensor): The keypoints object loss.\n \"\"\"\n batch_idx = batch_idx.flatten()\n batch_size = len(masks)\n\n # Find the maximum number of keypoints in a single image\n max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()\n\n # Create a tensor to hold batched keypoints\n batched_keypoints = torch.zeros(\n (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device\n )\n\n # TODO: any idea how to vectorize this?\n # Fill batched_keypoints with keypoints based on batch_idx\n for i in range(batch_size):\n keypoints_i = keypoints[batch_idx == i]\n batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i\n\n # Expand dimensions of target_gt_idx to match the shape of batched_keypoints\n target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)\n\n # Use target_gt_idx_expanded to select keypoints from batched_keypoints\n selected_keypoints = batched_keypoints.gather(\n 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])\n )\n\n # Divide coordinates by stride\n selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)\n\n kpts_loss = 0\n kpts_obj_loss = 0\n\n if masks.any():\n gt_kpt = selected_keypoints[masks]\n area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)\n pred_kpt = pred_kpts[masks]\n kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)\n kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss\n\n if pred_kpt.shape[-1] == 3:\n kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss\n\n return kpts_loss, kpts_obj_loss", "chunk_type": "class", "name": "v8PoseLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 483, "end_line": 642, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "Criterion class for computing training losses for YOLOv8 pose estimation.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", 
"ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "v8DetectionLoss" ], "chunk_id": "class_v8PoseLoss_298674cd" }, { "content": "class v8ClassificationLoss:\n \"\"\"Criterion class for computing training losses for classification.\"\"\"\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Compute the classification loss between predictions and true labels.\"\"\"\n preds = preds[1] if isinstance(preds, (list, tuple)) else preds\n loss = F.cross_entropy(preds, batch[\"cls\"], reduction=\"mean\")\n return loss, loss.detach()", "chunk_type": "class", "name": "v8ClassificationLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 645, "end_line": 652, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Criterion class for computing training losses for classification.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist" ], "chunk_id": "class_v8ClassificationLoss_649f5dc3" }, { "content": "class v8OBBLoss(v8DetectionLoss):\n \"\"\"Calculates losses for object detection, classification, and box distribution in rotated YOLO models.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled.\"\"\"\n super().__init__(model)\n self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)\n self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)\n\n def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:\n \"\"\"Preprocess targets for oriented bounding box detection.\"\"\"\n if targets.shape[0] == 0:\n out = torch.zeros(batch_size, 0, 6, device=self.device)\n else:\n i = targets[:, 0] # image index\n _, counts = i.unique(return_counts=True)\n counts = counts.to(dtype=torch.int32)\n out = torch.zeros(batch_size, counts.max(), 6, device=self.device)\n for j in range(batch_size):\n matches = i == j\n if n := matches.sum():\n bboxes = targets[matches, 2:]\n bboxes[..., :4].mul_(scale_tensor)\n out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)\n return out\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate and return the loss for oriented bounding box detection.\"\"\"\n loss = torch.zeros(3, device=self.device) # box, cls, dfl\n feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]\n batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n # b, grids, ..\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 
1).contiguous()\n pred_angle = pred_angle.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # targets\n try:\n batch_idx = batch[\"batch_idx\"].view(-1, 1)\n targets = torch.cat((batch_idx, batch[\"cls\"].view(-1, 1), batch[\"bboxes\"].view(-1, 5)), 1)\n rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()\n targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n except RuntimeError as e:\n raise TypeError(\n \"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\\n\"\n \"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, \"\n \"i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\\nVerify your dataset is a \"\n \"correctly formatted 'OBB' dataset using 'data=dota8.yaml' \"\n \"as an example.\\nSee https://docs.ultralytics.com/datasets/obb/ for help.\"\n ) from e\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)\n\n bboxes_for_assigner = pred_bboxes.clone().detach()\n # Only the first four elements need to be scaled\n bboxes_for_assigner[..., :4] *= stride_tensor\n _, target_bboxes, target_scores, fg_mask, _ = self.assigner(\n pred_scores.detach().sigmoid(),\n bboxes_for_assigner.type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n # Bbox loss\n if fg_mask.sum():\n target_bboxes[..., :4] /= stride_tensor\n loss[0], loss[2] = self.bbox_loss(\n pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask\n )\n else:\n loss[0] += (pred_angle * 0).sum()\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.cls # cls gain\n loss[2] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)\n\n def bbox_decode(\n self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor\n ) -> torch.Tensor:\n \"\"\"\n Decode predicted object bounding box coordinates from anchor points and distribution.\n\n Args:\n anchor_points (torch.Tensor): Anchor points, (h*w, 2).\n pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).\n pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).\n\n Returns:\n (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).\n \"\"\"\n if self.use_dfl:\n b, a, c = pred_dist.shape # batch, anchors, channels\n pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))\n return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)", "chunk_type": "class", "name": "v8OBBLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 655, "end_line": 770, "start_col": 0, "end_col": 95, "parent_name": null, "docstring": "Calculates losses for object detection, classification, and box distribution in rotated 
YOLO models.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "v8DetectionLoss" ], "chunk_id": "class_v8OBBLoss_9a1d2a7d" }, { "content": "class E2EDetectLoss:\n \"\"\"Criterion class for computing training losses for end-to-end detection.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model.\"\"\"\n self.one2many = v8DetectionLoss(model, tal_topk=10)\n self.one2one = v8DetectionLoss(model, tal_topk=1)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the sum of the loss for box, cls and dfl multiplied by batch size.\"\"\"\n preds = preds[1] if isinstance(preds, tuple) else preds\n one2many = preds[\"one2many\"]\n loss_one2many = self.one2many(one2many, batch)\n one2one = preds[\"one2one\"]\n loss_one2one = self.one2one(one2one, batch)\n return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]", "chunk_type": "class", "name": "E2EDetectLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 773, "end_line": 788, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": "Criterion class for computing training losses for end-to-end detection.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist" ], "chunk_id": "class_E2EDetectLoss_584a97b6" }, { "content": "class TVPDetectLoss:\n \"\"\"Criterion class for computing training losses for text-visual prompt detection.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model.\"\"\"\n self.vp_criterion = v8DetectionLoss(model)\n # NOTE: store following info as it's changeable in __call__\n self.ori_nc = self.vp_criterion.nc\n self.ori_no = self.vp_criterion.no\n self.ori_reg_max = self.vp_criterion.reg_max\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the loss for text-visual prompt detection.\"\"\"\n feats = preds[1] if isinstance(preds, tuple) else preds\n assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it\n\n if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:\n loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)\n return loss, 
loss.detach()\n\n vp_feats = self._get_vp_features(feats)\n vp_loss = self.vp_criterion(vp_feats, batch)\n box_loss = vp_loss[0][1]\n return box_loss, vp_loss[1]\n\n def _get_vp_features(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:\n \"\"\"Extract visual-prompt features from the model output.\"\"\"\n vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc\n\n self.vp_criterion.nc = vnc\n self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4\n self.vp_criterion.assigner.num_classes = vnc\n\n return [\n torch.cat((box, cls_vp), dim=1)\n for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]\n ]", "chunk_type": "class", "name": "TVPDetectLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 791, "end_line": 827, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Criterion class for computing training losses for text-visual prompt detection.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist" ], "chunk_id": "class_TVPDetectLoss_3756ee0d" }, { "content": "class TVPSegmentLoss(TVPDetectLoss):\n \"\"\"Criterion class for computing training losses for text-visual prompt segmentation.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model.\"\"\"\n super().__init__(model)\n self.vp_criterion = v8SegmentationLoss(model)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the loss for text-visual prompt segmentation.\"\"\"\n feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]\n assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it\n\n if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:\n loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)\n return loss, loss.detach()\n\n vp_feats = self._get_vp_features(feats)\n vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)\n cls_loss = vp_loss[0][2]\n return cls_loss, vp_loss[1]", "chunk_type": "class", "name": "TVPSegmentLoss", "file_path": "ultralytics\\ultralytics\\utils\\loss.py", "start_line": 830, "end_line": 850, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Criterion class for computing training losses for text-visual prompt segmentation.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.ops.crop_mask", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "ultralytics.utils.tal.RotatedTaskAlignedAssigner", "ultralytics.utils.tal.TaskAlignedAssigner", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", 
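_get_vp_features in TVPDetectLoss splits each head output into box-regression channels, the original text-prompt class channels and the extra visual-prompt class channels, then keeps only box plus visual-prompt classes. A sketch of that channel surgery with a hypothetical channel layout:

import torch

reg_max, ori_nc, vnc = 16, 80, 8    # hypothetical box bins, text-prompt classes, visual-prompt classes
feats = [torch.randn(2, reg_max * 4 + ori_nc + vnc, s, s) for s in (80, 40, 20)]

# Keep the box-regression channels, drop the text-prompt classes,
# keep only the visual-prompt classes appended at the end
vp_feats = [
    torch.cat((box, cls_vp), dim=1)
    for box, _, cls_vp in (xi.split((reg_max * 4, ori_nc, vnc), dim=1) for xi in feats)
]
print([f.shape[1] for f in vp_feats])   # [72, 72, 72] -> reg_max * 4 + vnc channels per level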
"ultralytics.utils.torch_utils.autocast", "metrics.bbox_iou", "metrics.probiou", "tal.bbox2dist", "TVPDetectLoss" ], "chunk_id": "class_TVPSegmentLoss_36d5dc11" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_e72c7ed4" }, { "content": "import warnings", "chunk_type": "import", "name": "warnings", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_warnings_6a87eafa" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_5c08fd5b" }, { "content": "from typing import Any, Dict, List, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Tuple, Union", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple, Union_6ad3df08" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_bfe3320a" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_855bdd3f" }, { "content": "from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings", "chunk_type": "import", "name": "LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 99, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings_95373bc3" }, { "content": "OKS_SIGMA = (\n np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])\n / 10.0\n)", "chunk_type": "variable", "name": "OKS_SIGMA", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 14, "end_line": 17, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, 
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_OKS_SIGMA_162900b5" }, { "content": "def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:\n \"\"\"\n Calculate the intersection over box2 area given box1 and box2.\n\n Args:\n box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.\n box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.\n iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.\n \"\"\"\n # Get the coordinates of bounding boxes\n b1_x1, b1_y1, b1_x2, b1_y2 = box1.T\n b2_x1, b2_y1, b2_x2, b2_y2 = box2.T\n\n # Intersection area\n inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (\n np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)\n ).clip(0)\n\n # Box2 area\n area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)\n if iou:\n box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)\n area = area + box1_area[:, None] - inter_area\n\n # Intersection over box2 area\n return inter_area / (area + eps)", "chunk_type": "function", "name": "bbox_ioa", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 20, "end_line": 49, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "Calculate the intersection over box2 area given box1 and box2.\n\nArgs:\n box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.\n box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.\n iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.", "parameters": [ "box1: np.ndarray", "box2: np.ndarray", "iou: bool", "eps: float" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_bbox_ioa_55f2214d" }, { "content": "def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:\n \"\"\"\n Calculate intersection-over-union (IoU) of boxes.\n\n Args:\n box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.\n box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.\n\n References:\n https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py\n \"\"\"\n # NOTE: Need .float() to get accurate iou values\n # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)\n (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 
2), box2.float().unsqueeze(0).chunk(2, 2)\n inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)\n\n # IoU = inter / (area1 + area2 - inter)\n return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)", "chunk_type": "function", "name": "box_iou", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 52, "end_line": 73, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": "Calculate intersection-over-union (IoU) of boxes.\n\nArgs:\n box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.\n box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.\n\nReferences:\n https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py", "parameters": [ "box1: torch.Tensor", "box2: torch.Tensor", "eps: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_box_iou_6f011eb8" }, { "content": "def bbox_iou(\n box1: torch.Tensor,\n box2: torch.Tensor,\n xywh: bool = True,\n GIoU: bool = False,\n DIoU: bool = False,\n CIoU: bool = False,\n eps: float = 1e-7,\n) -> torch.Tensor:\n \"\"\"\n Calculate the Intersection over Union (IoU) between bounding boxes.\n\n This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.\n For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).\n Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,\n or (x1, y1, x2, y2) if `xywh=False`.\n\n Args:\n box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. 
If False, input boxes are in\n (x1, y1, x2, y2) format.\n GIoU (bool, optional): If True, calculate Generalized IoU.\n DIoU (bool, optional): If True, calculate Distance IoU.\n CIoU (bool, optional): If True, calculate Complete IoU.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.\n \"\"\"\n # Get the coordinates of bounding boxes\n if xywh: # transform from xywh to xyxy\n (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)\n w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2\n b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_\n b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_\n else: # x1, y1, x2, y2 = box1\n b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)\n b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)\n w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps\n w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps\n\n # Intersection area\n inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (\n b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)\n ).clamp_(0)\n\n # Union Area\n union = w1 * h1 + w2 * h2 - inter + eps\n\n # IoU\n iou = inter / union\n if CIoU or DIoU or GIoU:\n cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width\n ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height\n if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1\n c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared\n rho2 = (\n (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)\n ) / 4 # center dist**2\n if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47\n v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)\n with torch.no_grad():\n alpha = v / (v - iou + (1 + eps))\n return iou - (rho2 / c2 + v * alpha) # CIoU\n return iou - rho2 / c2 # DIoU\n c_area = cw * ch + eps # convex area\n return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf\n return iou # IoU", "chunk_type": "function", "name": "bbox_iou", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 76, "end_line": 144, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Calculate the Intersection over Union (IoU) between bounding boxes.\n\nThis function supports various shapes for `box1` and `box2` as long as the last dimension is 4.\nFor instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).\nInternally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,\nor (x1, y1, x2, y2) if `xywh=False`.\n\nArgs:\n box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. 
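Since bbox_iou accepts either xywh or corner boxes plus optional GIoU/DIoU/CIoU flags, a call might look like the following (assumes the ultralytics package is importable; the coordinate values are arbitrary):

import torch
from ultralytics.utils.metrics import bbox_iou   # signature as defined in the chunk above

b1 = torch.tensor([[50.0, 50.0, 20.0, 40.0]])     # xywh (default)
b2 = torch.tensor([[55.0, 60.0, 20.0, 40.0]])
print(bbox_iou(b1, b2))                           # plain IoU
print(bbox_iou(b1, b2, CIoU=True))                # adds centre-distance and aspect-ratio terms

xyxy1 = torch.tensor([[40.0, 30.0, 60.0, 70.0]])
print(bbox_iou(xyxy1, xyxy1, xywh=False))         # ~1.0 for identical corner-format boxes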
If False, input boxes are in\n (x1, y1, x2, y2) format.\n GIoU (bool, optional): If True, calculate Generalized IoU.\n DIoU (bool, optional): If True, calculate Distance IoU.\n CIoU (bool, optional): If True, calculate Complete IoU.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.", "parameters": [ "box1: torch.Tensor", "box2: torch.Tensor", "xywh: bool", "GIoU: bool", "DIoU: bool", "CIoU: bool", "eps: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 5, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_bbox_iou_459ddd01" }, { "content": "def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:\n \"\"\"\n Calculate masks IoU.\n\n Args:\n mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the\n product of image width and height.\n mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the\n product of image width and height.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): A tensor of shape (N, M) representing masks IoU.\n \"\"\"\n intersection = torch.matmul(mask1, mask2.T).clamp_(0)\n union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection\n return intersection / (union + eps)", "chunk_type": "function", "name": "mask_iou", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 147, "end_line": 163, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "Calculate masks IoU.\n\nArgs:\n mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the\n product of image width and height.\n mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the\n product of image width and height.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): A tensor of shape (N, M) representing masks IoU.", "parameters": [ "mask1: torch.Tensor", "mask2: torch.Tensor", "eps: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_mask_iou_acb5a26e" }, { "content": "def kpt_iou(\n kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: List[float], eps: float = 1e-7\n) -> torch.Tensor:\n \"\"\"\n Calculate Object Keypoint Similarity (OKS).\n\n Args:\n kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.\n kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.\n area (torch.Tensor): A tensor of shape (N,) 
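mask_iou above computes pairwise IoU on flattened binary masks with a single matrix product. A tiny sketch on toy masks small enough to check by hand:

import torch

# Two flattened binary masks per set (n = H * W)
m1 = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1]], dtype=torch.float32)
m2 = torch.tensor([[1, 0, 0, 0], [1, 1, 1, 1]], dtype=torch.float32)

inter = torch.matmul(m1, m2.T).clamp_(0)                 # pairwise overlapping pixels
union = m1.sum(1)[:, None] + m2.sum(1)[None] - inter     # (area1 + area2) - intersection
print(inter / (union + 1e-7))                            # [[0.5, 0.5], [0.0, 0.5]]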
representing areas from ground truth.\n sigma (list): A list containing 17 values representing keypoint scales.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.\n \"\"\"\n d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)\n sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )\n kpt_mask = kpt1[..., 2] != 0 # (N, 17)\n e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval\n # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula\n return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)", "chunk_type": "function", "name": "kpt_iou", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 166, "end_line": 187, "start_col": 0, "end_col": 87, "parent_name": null, "docstring": "Calculate Object Keypoint Similarity (OKS).\n\nArgs:\n kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.\n kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.\n area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.\n sigma (list): A list containing 17 values representing keypoint scales.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.", "parameters": [ "kpt1: torch.Tensor", "kpt2: torch.Tensor", "area: torch.Tensor", "sigma: List[float]", "eps: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_kpt_iou_95674ad5" }, { "content": "def _get_covariance_matrix(boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Generate covariance matrix from oriented bounding boxes.\n\n Args:\n boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.\n\n Returns:\n (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.\n \"\"\"\n # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.\n gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)\n a, b, c = gbbs.split(1, dim=-1)\n cos = c.cos()\n sin = c.sin()\n cos2 = cos.pow(2)\n sin2 = sin.pow(2)\n return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin", "chunk_type": "function", "name": "_get_covariance_matrix", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 190, "end_line": 207, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": "Generate covariance matrix from oriented bounding boxes.\n\nArgs:\n boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.\n\nReturns:\n (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.", "parameters": [ "boxes: torch.Tensor" ], "return_type": "Tuple[torch.Tensor, torch.Tensor, torch.Tensor]", "decorators": [], "complexity_score": 
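A quick sketch of how kpt_iou (Object Keypoint Similarity) might be exercised, again assuming the metrics-module import path; the 17-keypoint layout and the uniform sigma below are illustrative stand-ins for COCO's per-keypoint values.

import torch
from ultralytics.utils.metrics import kpt_iou  # assumed import path (see file_path above)

sigma = [0.05] * 17                    # illustrative; COCO uses keypoint-specific sigmas
kpt_gt = torch.rand(2, 17, 3)          # (N, 17, 3): x, y, visibility
kpt_gt[..., 2] = 1.0                   # mark every keypoint as labeled/visible
kpt_pred = kpt_gt.clone()              # a perfect prediction, for the sake of the example
area = torch.tensor([100.0, 100.0])    # ground-truth box areas, shape (N,)

oks = kpt_iou(kpt_gt, kpt_pred, area, sigma)
print(oks.shape)                       # torch.Size([2, 2]); diagonal entries are ~1.0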
1, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function__get_covariance_matrix_9af498db" }, { "content": "def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:\n \"\"\"\n Calculate probabilistic IoU between oriented bounding boxes.\n\n Args:\n obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.\n obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.\n CIoU (bool, optional): If True, calculate CIoU.\n eps (float, optional): Small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): OBB similarities, shape (N,).\n\n Notes:\n OBB format: [center_x, center_y, width, height, rotation_angle].\n\n References:\n https://arxiv.org/pdf/2106.06072v1.pdf\n \"\"\"\n x1, y1 = obb1[..., :2].split(1, dim=-1)\n x2, y2 = obb2[..., :2].split(1, dim=-1)\n a1, b1, c1 = _get_covariance_matrix(obb1)\n a2, b2, c2 = _get_covariance_matrix(obb2)\n\n t1 = (\n ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)\n ) * 0.25\n t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5\n t3 = (\n ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))\n / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)\n + eps\n ).log() * 0.5\n bd = (t1 + t2 + t3).clamp(eps, 100.0)\n hd = (1.0 - (-bd).exp() + eps).sqrt()\n iou = 1 - hd\n if CIoU: # only include the wh aspect ratio part\n w1, h1 = obb1[..., 2:4].split(1, dim=-1)\n w2, h2 = obb2[..., 2:4].split(1, dim=-1)\n v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)\n with torch.no_grad():\n alpha = v / (v - iou + (1 + eps))\n return iou - v * alpha # CIoU\n return iou", "chunk_type": "function", "name": "probiou", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 210, "end_line": 253, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Calculate probabilistic IoU between oriented bounding boxes.\n\nArgs:\n obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.\n obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.\n CIoU (bool, optional): If True, calculate CIoU.\n eps (float, optional): Small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): OBB similarities, shape (N,).\n\nNotes:\n OBB format: [center_x, center_y, width, height, rotation_angle].\n\nReferences:\n https://arxiv.org/pdf/2106.06072v1.pdf", "parameters": [ "obb1: torch.Tensor", "obb2: torch.Tensor", "CIoU: bool", "eps: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_probiou_78b462b7" }, { "content": "def batch_probiou(\n obb1: Union[torch.Tensor, np.ndarray], obb2: 
Union[torch.Tensor, np.ndarray], eps: float = 1e-7\n) -> torch.Tensor:\n \"\"\"\n Calculate the probabilistic IoU between oriented bounding boxes.\n\n Args:\n obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.\n obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): A tensor of shape (N, M) representing obb similarities.\n\n References:\n https://arxiv.org/pdf/2106.06072v1.pdf\n \"\"\"\n obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1\n obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2\n\n x1, y1 = obb1[..., :2].split(1, dim=-1)\n x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))\n a1, b1, c1 = _get_covariance_matrix(obb1)\n a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))\n\n t1 = (\n ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)\n ) * 0.25\n t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5\n t3 = (\n ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))\n / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)\n + eps\n ).log() * 0.5\n bd = (t1 + t2 + t3).clamp(eps, 100.0)\n hd = (1.0 - (-bd).exp() + eps).sqrt()\n return 1 - hd", "chunk_type": "function", "name": "batch_probiou", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 256, "end_line": 292, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Calculate the probabilistic IoU between oriented bounding boxes.\n\nArgs:\n obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.\n obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): A tensor of shape (N, M) representing obb similarities.\n\nReferences:\n https://arxiv.org/pdf/2106.06072v1.pdf", "parameters": [ "obb1: Union[torch.Tensor, np.ndarray]", "obb2: Union[torch.Tensor, np.ndarray]", "eps: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 3, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_batch_probiou_1b0ac7e4" }, { "content": "def smooth_bce(eps: float = 0.1) -> Tuple[float, float]:\n \"\"\"\n Compute smoothed positive and negative Binary Cross-Entropy targets.\n\n Args:\n eps (float, optional): The epsilon value for label smoothing.\n\n Returns:\n pos (float): Positive label smoothing BCE target.\n neg (float): Negative label smoothing BCE target.\n\n References:\n https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441\n \"\"\"\n return 1.0 - 0.5 * eps, 0.5 * eps", "chunk_type": "function", "name": "smooth_bce", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 295, "end_line": 309, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Compute 
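batch_probiou is the pairwise (N x M) counterpart used for OBB matching; a sketch with numpy inputs, assuming the same import path and invented boxes.

import numpy as np
from ultralytics.utils.metrics import batch_probiou  # assumed import path (see file_path above)

gt = np.array([[50.0, 50.0, 20.0, 10.0, 0.0],
               [120.0, 80.0, 30.0, 15.0, 0.3]])       # (N, 5) xywhr ground truth
preds = np.array([[52.0, 49.0, 21.0, 10.0, 0.05],
                  [120.0, 80.0, 30.0, 15.0, 0.3],
                  [300.0, 300.0, 10.0, 10.0, 0.0]])    # (M, 5) xywhr predictions

sim = batch_probiou(gt, preds)    # numpy inputs are converted to tensors internally
print(sim.shape)                  # torch.Size([2, 3]); sim[1, 1] is ~1.0 for the identical pair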
smoothed positive and negative Binary Cross-Entropy targets.\n\nArgs:\n eps (float, optional): The epsilon value for label smoothing.\n\nReturns:\n pos (float): Positive label smoothing BCE target.\n neg (float): Negative label smoothing BCE target.\n\nReferences:\n https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441", "parameters": [ "eps: float" ], "return_type": "Tuple[float, float]", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_smooth_bce_94ea830c" }, { "content": "class ConfusionMatrix(DataExportMixin):\n \"\"\"\n A class for calculating and updating a confusion matrix for object detection and classification tasks.\n\n Attributes:\n task (str): The type of task, either 'detect' or 'classify'.\n matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.\n nc (int): The number of category.\n names (List[str]): The names of the classes, used as labels on the plot.\n \"\"\"\n\n def __init__(self, names: List[str] = [], task: str = \"detect\"):\n \"\"\"\n Initialize a ConfusionMatrix instance.\n\n Args:\n names (List[str], optional): Names of classes, used as labels on the plot.\n task (str, optional): Type of task, either 'detect' or 'classify'.\n \"\"\"\n self.task = task\n self.nc = len(names) # number of classes\n self.matrix = np.zeros((self.nc, self.nc)) if self.task == \"classify\" else np.zeros((self.nc + 1, self.nc + 1))\n self.names = names # name of classes\n\n def process_cls_preds(self, preds, targets):\n \"\"\"\n Update confusion matrix for classification task.\n\n Args:\n preds (Array[N, min(nc,5)]): Predicted class labels.\n targets (Array[N, 1]): Ground truth class labels.\n \"\"\"\n preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)\n for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):\n self.matrix[p][t] += 1\n\n def process_batch(\n self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45\n ) -> None:\n \"\"\"\n Update confusion matrix for object detection task.\n\n Args:\n detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.\n Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be\n Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.\n batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and\n 'cls' (Array[M]) keys, where M is the number of ground truth objects.\n conf (float, optional): Confidence threshold for detections.\n iou_thres (float, optional): IoU threshold for matching detections to ground truth.\n \"\"\"\n gt_cls, gt_bboxes = batch[\"cls\"], batch[\"bboxes\"]\n is_obb = gt_bboxes.shape[1] == 5 # check if boxes contains angle for OBB\n conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf # apply 0.25 if default val conf is passed\n no_pred = len(detections[\"cls\"]) == 0\n if gt_cls.shape[0] == 0: # Check if labels is empty\n if not no_pred:\n detections = {k: detections[k][detections[\"conf\"] > conf] for k in {\"cls\", \"bboxes\"}}\n detection_classes = 
detections[\"cls\"].int().tolist()\n for dc in detection_classes:\n self.matrix[dc, self.nc] += 1 # false positives\n return\n if no_pred:\n gt_classes = gt_cls.int().tolist()\n for gc in gt_classes:\n self.matrix[self.nc, gc] += 1 # background FN\n return\n\n detections = {k: detections[k][detections[\"conf\"] > conf] for k in {\"cls\", \"bboxes\"}}\n gt_classes = gt_cls.int().tolist()\n detection_classes = detections[\"cls\"].int().tolist()\n bboxes = detections[\"bboxes\"]\n iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)\n\n x = torch.where(iou > iou_thres)\n if x[0].shape[0]:\n matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()\n if x[0].shape[0] > 1:\n matches = matches[matches[:, 2].argsort()[::-1]]\n matches = matches[np.unique(matches[:, 1], return_index=True)[1]]\n matches = matches[matches[:, 2].argsort()[::-1]]\n matches = matches[np.unique(matches[:, 0], return_index=True)[1]]\n else:\n matches = np.zeros((0, 3))\n\n n = matches.shape[0] > 0\n m0, m1, _ = matches.transpose().astype(int)\n for i, gc in enumerate(gt_classes):\n j = m0 == i\n if n and sum(j) == 1:\n self.matrix[detection_classes[m1[j].item()], gc] += 1 # correct\n else:\n self.matrix[self.nc, gc] += 1 # true background\n\n for i, dc in enumerate(detection_classes):\n if not any(m1 == i):\n self.matrix[dc, self.nc] += 1 # predicted background\n\n def matrix(self):\n \"\"\"Return the confusion matrix.\"\"\"\n return self.matrix\n\n def tp_fp(self) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"\n Return true positives and false positives.\n\n Returns:\n tp (np.ndarray): True positives.\n fp (np.ndarray): False positives.\n \"\"\"\n tp = self.matrix.diagonal() # true positives\n fp = self.matrix.sum(1) - tp # false positives\n # fn = self.matrix.sum(0) - tp # false negatives (missed detections)\n return (tp, fp) if self.task == \"classify\" else (tp[:-1], fp[:-1]) # remove background class if task=detect\n\n @TryExcept(msg=\"ConfusionMatrix plot failure\")\n @plt_settings()\n def plot(self, normalize: bool = True, save_dir: str = \"\", on_plot=None):\n \"\"\"\n Plot the confusion matrix using matplotlib and save it to a file.\n\n Args:\n normalize (bool, optional): Whether to normalize the confusion matrix.\n save_dir (str, optional): Directory where the plot will be saved.\n on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns\n array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)\n\n fig, ax = plt.subplots(1, 1, figsize=(12, 9))\n names, n = self.names, self.nc\n if self.nc >= 100: # downsample for large class count\n k = max(2, self.nc // 60) # step size for downsampling, always > 1\n keep_idx = slice(None, None, k) # create slice instead of array\n names = names[keep_idx] # slice class names\n array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols\n n = (self.nc + k - 1) // k # number of retained classes\n nc = nn = n if self.task == \"classify\" else n + 1 # adjust for background if needed\n ticklabels = (names + [\"background\"]) if (0 < nn < 99) and (nn == nc) else \"auto\"\n xy_ticks = np.arange(len(ticklabels))\n tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6\n label_fontsize = max(6, 12 - 0.1 * nc)\n title_fontsize = max(6, 12 - 0.1 * nc)\n btm = max(0.1, 0.25 - 
0.001 * nc) # Minimum value is 0.1\n with warnings.catch_warnings():\n warnings.simplefilter(\"ignore\") # suppress empty matrix RuntimeWarning: All-NaN slice encountered\n im = ax.imshow(array, cmap=\"Blues\", vmin=0.0, interpolation=\"none\")\n ax.xaxis.set_label_position(\"bottom\")\n if nc < 30: # Add score for each cell of confusion matrix\n color_threshold = 0.45 * (1 if normalize else np.nanmax(array)) # text color threshold\n for i, row in enumerate(array[:nc]):\n for j, val in enumerate(row[:nc]):\n val = array[i, j]\n if np.isnan(val):\n continue\n ax.text(\n j,\n i,\n f\"{val:.2f}\" if normalize else f\"{int(val)}\",\n ha=\"center\",\n va=\"center\",\n fontsize=10,\n color=\"white\" if val > color_threshold else \"black\",\n )\n cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)\n title = \"Confusion Matrix\" + \" Normalized\" * normalize\n ax.set_xlabel(\"True\", fontsize=label_fontsize, labelpad=10)\n ax.set_ylabel(\"Predicted\", fontsize=label_fontsize, labelpad=10)\n ax.set_title(title, fontsize=title_fontsize, pad=20)\n ax.set_xticks(xy_ticks)\n ax.set_yticks(xy_ticks)\n ax.tick_params(axis=\"x\", bottom=True, top=False, labelbottom=True, labeltop=False)\n ax.tick_params(axis=\"y\", left=True, right=False, labelleft=True, labelright=False)\n if ticklabels != \"auto\":\n ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha=\"center\")\n ax.set_yticklabels(ticklabels, fontsize=tick_fontsize)\n for s in {\"left\", \"right\", \"bottom\", \"top\", \"outline\"}:\n if s != \"outline\":\n ax.spines[s].set_visible(False) # Confusion matrix plot don't have outline\n cbar.ax.spines[s].set_visible(False)\n fig.subplots_adjust(left=0, right=0.84, top=0.94, bottom=btm) # Adjust layout to ensure equal margins\n plot_fname = Path(save_dir) / f\"{title.lower().replace(' ', '_')}.png\"\n fig.savefig(plot_fname, dpi=250)\n plt.close(fig)\n if on_plot:\n on_plot(plot_fname)\n\n def print(self):\n \"\"\"Print the confusion matrix to the console.\"\"\"\n for i in range(self.matrix.shape[0]):\n LOGGER.info(\" \".join(map(str, self.matrix[i])))\n\n def summary(self, normalize: bool = False, decimals: int = 5) -> List[Dict[str, float]]:\n \"\"\"\n Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional\n normalization. 
This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL.\n\n Args:\n normalize (bool): Whether to normalize the confusion matrix values.\n decimals (int): Number of decimal places to round the output values to.\n\n Returns:\n (List[Dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes.\n\n Examples:\n >>> results = model.val(data=\"coco8.yaml\", plots=True)\n >>> cm_dict = results.confusion_matrix.summary(normalize=True, decimals=5)\n >>> print(cm_dict)\n \"\"\"\n import re\n\n names = self.names if self.task == \"classify\" else self.names + [\"background\"]\n clean_names, seen = [], set()\n for name in names:\n clean_name = re.sub(r\"[^a-zA-Z0-9_]\", \"_\", name)\n original_clean = clean_name\n counter = 1\n while clean_name.lower() in seen:\n clean_name = f\"{original_clean}_{counter}\"\n counter += 1\n seen.add(clean_name.lower())\n clean_names.append(clean_name)\n array = (self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)).round(decimals)\n return [\n dict({\"Predicted\": clean_names[i]}, **{clean_names[j]: array[i, j] for j in range(len(clean_names))})\n for i in range(len(clean_names))\n ]", "chunk_type": "class", "name": "ConfusionMatrix", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 312, "end_line": 540, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class for calculating and updating a confusion matrix for object detection and classification tasks.\n\nAttributes:\n task (str): The type of task, either 'detect' or 'classify'.\n matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.\n nc (int): The number of category.\n names (List[str]): The names of the classes, used as labels on the plot.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "DataExportMixin" ], "chunk_id": "class_ConfusionMatrix_f68dd0e5" }, { "content": "def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:\n \"\"\"Box filter of fraction f.\"\"\"\n nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)\n p = np.ones(nf // 2) # ones padding\n yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded\n return np.convolve(yp, np.ones(nf) / nf, mode=\"valid\") # y-smoothed", "chunk_type": "function", "name": "smooth", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 543, "end_line": 548, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": "Box filter of fraction f.", "parameters": [ "y: np.ndarray", "f: float" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": 
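A minimal sketch of driving ConfusionMatrix in classification mode, assuming the import path indicated by the chunk metadata; class names and predictions are invented.

import torch
from ultralytics.utils.metrics import ConfusionMatrix  # assumed import path (see file_path above)

cm = ConfusionMatrix(names=["cat", "dog", "bird"], task="classify")
preds = [torch.tensor([[0], [1], [2], [1]])]   # per-batch top-1 class indices, shape (N, 1)
targets = [torch.tensor([0, 1, 2, 0])]         # per-batch ground-truth class indices
cm.process_cls_preds(preds, targets)

tp, fp = cm.tp_fp()
print(cm.matrix)                               # 3x3 counts; the last sample lands off-diagonal
print(tp, fp)                                  # [1. 1. 1.] and [0. 1. 0.] for this toy batch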
"function_smooth_9012a130" }, { "content": "def plot_pr_curve(\n px: np.ndarray,\n py: np.ndarray,\n ap: np.ndarray,\n save_dir: Path = Path(\"pr_curve.png\"),\n names: Dict[int, str] = {},\n on_plot=None,\n):\n \"\"\"\n Plot precision-recall curve.\n\n Args:\n px (np.ndarray): X values for the PR curve.\n py (np.ndarray): Y values for the PR curve.\n ap (np.ndarray): Average precision values.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n on_plot (callable, optional): Function to call after plot is saved.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)\n py = np.stack(py, axis=1)\n\n if 0 < len(names) < 21: # display per-class legend if < 21 classes\n for i, y in enumerate(py.T):\n ax.plot(px, y, linewidth=1, label=f\"{names[i]} {ap[i, 0]:.3f}\") # plot(recall, precision)\n else:\n ax.plot(px, py, linewidth=1, color=\"grey\") # plot(recall, precision)\n\n ax.plot(px, py.mean(1), linewidth=3, color=\"blue\", label=f\"all classes {ap[:, 0].mean():.3f} mAP@0.5\")\n ax.set_xlabel(\"Recall\")\n ax.set_ylabel(\"Precision\")\n ax.set_xlim(0, 1)\n ax.set_ylim(0, 1)\n ax.legend(bbox_to_anchor=(1.04, 1), loc=\"upper left\")\n ax.set_title(\"Precision-Recall Curve\")\n fig.savefig(save_dir, dpi=250)\n plt.close(fig)\n if on_plot:\n on_plot(save_dir)", "chunk_type": "function", "name": "plot_pr_curve", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 552, "end_line": 592, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Plot precision-recall curve.\n\nArgs:\n px (np.ndarray): X values for the PR curve.\n py (np.ndarray): Y values for the PR curve.\n ap (np.ndarray): Average precision values.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n on_plot (callable, optional): Function to call after plot is saved.", "parameters": [ "px: np.ndarray", "py: np.ndarray", "ap: np.ndarray", "save_dir: Path", "names: Dict[int, str]", "on_plot" ], "return_type": null, "decorators": [ "plt_settings()" ], "complexity_score": 4, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_plot_pr_curve_590548b1" }, { "content": "def plot_mc_curve(\n px: np.ndarray,\n py: np.ndarray,\n save_dir: Path = Path(\"mc_curve.png\"),\n names: Dict[int, str] = {},\n xlabel: str = \"Confidence\",\n ylabel: str = \"Metric\",\n on_plot=None,\n):\n \"\"\"\n Plot metric-confidence curve.\n\n Args:\n px (np.ndarray): X values for the metric-confidence curve.\n py (np.ndarray): Y values for the metric-confidence curve.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n xlabel (str, optional): X-axis label.\n ylabel (str, optional): Y-axis label.\n on_plot (callable, optional): Function to call after plot is saved.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n fig, ax = plt.subplots(1, 1, figsize=(9, 6), 
tight_layout=True)\n\n if 0 < len(names) < 21: # display per-class legend if < 21 classes\n for i, y in enumerate(py):\n ax.plot(px, y, linewidth=1, label=f\"{names[i]}\") # plot(confidence, metric)\n else:\n ax.plot(px, py.T, linewidth=1, color=\"grey\") # plot(confidence, metric)\n\n y = smooth(py.mean(0), 0.1)\n ax.plot(px, y, linewidth=3, color=\"blue\", label=f\"all classes {y.max():.2f} at {px[y.argmax()]:.3f}\")\n ax.set_xlabel(xlabel)\n ax.set_ylabel(ylabel)\n ax.set_xlim(0, 1)\n ax.set_ylim(0, 1)\n ax.legend(bbox_to_anchor=(1.04, 1), loc=\"upper left\")\n ax.set_title(f\"{ylabel}-Confidence Curve\")\n fig.savefig(save_dir, dpi=250)\n plt.close(fig)\n if on_plot:\n on_plot(save_dir)", "chunk_type": "function", "name": "plot_mc_curve", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 596, "end_line": 638, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Plot metric-confidence curve.\n\nArgs:\n px (np.ndarray): X values for the metric-confidence curve.\n py (np.ndarray): Y values for the metric-confidence curve.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n xlabel (str, optional): X-axis label.\n ylabel (str, optional): Y-axis label.\n on_plot (callable, optional): Function to call after plot is saved.", "parameters": [ "px: np.ndarray", "py: np.ndarray", "save_dir: Path", "names: Dict[int, str]", "xlabel: str", "ylabel: str", "on_plot" ], "return_type": null, "decorators": [ "plt_settings()" ], "complexity_score": 4, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_plot_mc_curve_70749d7c" }, { "content": "def compute_ap(recall: List[float], precision: List[float]) -> Tuple[float, np.ndarray, np.ndarray]:\n \"\"\"\n Compute the average precision (AP) given the recall and precision curves.\n\n Args:\n recall (list): The recall curve.\n precision (list): The precision curve.\n\n Returns:\n ap (float): Average precision.\n mpre (np.ndarray): Precision envelope curve.\n mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.\n \"\"\"\n # Append sentinel values to beginning and end\n mrec = np.concatenate(([0.0], recall, [1.0]))\n mpre = np.concatenate(([1.0], precision, [0.0]))\n\n # Compute the precision envelope\n mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))\n\n # Integrate area under curve\n method = \"interp\" # methods: 'continuous', 'interp'\n if method == \"interp\":\n x = np.linspace(0, 1, 101) # 101-point interp (COCO)\n func = np.trapezoid if checks.check_version(np.__version__, \">=2.0\") else np.trapz # np.trapz deprecated\n ap = func(np.interp(x, mrec, mpre), x) # integrate\n else: # 'continuous'\n i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes\n ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve\n\n return ap, mpre, mrec", "chunk_type": "function", "name": "compute_ap", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 641, "end_line": 671, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Compute the average precision (AP) given 
the recall and precision curves.\n\nArgs:\n recall (list): The recall curve.\n precision (list): The precision curve.\n\nReturns:\n ap (float): Average precision.\n mpre (np.ndarray): Precision envelope curve.\n mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.", "parameters": [ "recall: List[float]", "precision: List[float]" ], "return_type": "Tuple[float, np.ndarray, np.ndarray]", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_compute_ap_0f423827" }, { "content": "def ap_per_class(\n tp: np.ndarray,\n conf: np.ndarray,\n pred_cls: np.ndarray,\n target_cls: np.ndarray,\n plot: bool = False,\n on_plot=None,\n save_dir: Path = Path(),\n names: Dict[int, str] = {},\n eps: float = 1e-16,\n prefix: str = \"\",\n) -> Tuple:\n \"\"\"\n Compute the average precision per class for object detection evaluation.\n\n Args:\n tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).\n conf (np.ndarray): Array of confidence scores of the detections.\n pred_cls (np.ndarray): Array of predicted classes of the detections.\n target_cls (np.ndarray): Array of true classes of the detections.\n plot (bool, optional): Whether to plot PR curves or not.\n on_plot (callable, optional): A callback to pass plots path and data when they are rendered.\n save_dir (Path, optional): Directory to save the PR curves.\n names (Dict[int, str], optional): Dictionary of class names to plot PR curves.\n eps (float, optional): A small value to avoid division by zero.\n prefix (str, optional): A prefix string for saving the plot files.\n\n Returns:\n tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.\n fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.\n p (np.ndarray): Precision values at threshold given by max F1 metric for each class.\n r (np.ndarray): Recall values at threshold given by max F1 metric for each class.\n f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.\n ap (np.ndarray): Average precision for each class at different IoU thresholds.\n unique_classes (np.ndarray): An array of unique classes that have data.\n p_curve (np.ndarray): Precision curves for each class.\n r_curve (np.ndarray): Recall curves for each class.\n f1_curve (np.ndarray): F1-score curves for each class.\n x (np.ndarray): X-axis values for the curves.\n prec_values (np.ndarray): Precision values at mAP@0.5 for each class.\n \"\"\"\n # Sort by objectness\n i = np.argsort(-conf)\n tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]\n\n # Find unique classes\n unique_classes, nt = np.unique(target_cls, return_counts=True)\n nc = unique_classes.shape[0] # number of classes, number of detections\n\n # Create Precision-Recall curve and compute AP for each class\n x, prec_values = np.linspace(0, 1, 1000), []\n\n # Average precision, precision and recall curves\n ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))\n for ci, c in enumerate(unique_classes):\n i = pred_cls == c\n n_l = nt[ci] # number of 
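Sketch for compute_ap with a toy recall/precision curve, assuming the same import path; note the sentinel values it appends and the 101-point COCO-style interpolation it integrates over.

import numpy as np
from ultralytics.utils.metrics import compute_ap  # assumed import path (see file_path above)

recall = [0.0, 0.2, 0.4, 0.6, 0.8]
precision = [1.0, 0.9, 0.8, 0.6, 0.5]

ap, mpre, mrec = compute_ap(recall, precision)
print(round(float(ap), 4))      # area under the interpolated PR curve
print(mrec[0], mrec[-1])        # sentinels 0.0 and 1.0 added by the function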
labels\n n_p = i.sum() # number of predictions\n if n_p == 0 or n_l == 0:\n continue\n\n # Accumulate FPs and TPs\n fpc = (1 - tp[i]).cumsum(0)\n tpc = tp[i].cumsum(0)\n\n # Recall\n recall = tpc / (n_l + eps) # recall curve\n r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases\n\n # Precision\n precision = tpc / (tpc + fpc) # precision curve\n p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score\n\n # AP from recall-precision curve\n for j in range(tp.shape[1]):\n ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])\n if j == 0:\n prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5\n\n prec_values = np.array(prec_values) if prec_values else np.zeros((1, 1000)) # (nc, 1000)\n\n # Compute F1 (harmonic mean of precision and recall)\n f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)\n names = {i: names[k] for i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data\n if plot:\n plot_pr_curve(x, prec_values, ap, save_dir / f\"{prefix}PR_curve.png\", names, on_plot=on_plot)\n plot_mc_curve(x, f1_curve, save_dir / f\"{prefix}F1_curve.png\", names, ylabel=\"F1\", on_plot=on_plot)\n plot_mc_curve(x, p_curve, save_dir / f\"{prefix}P_curve.png\", names, ylabel=\"Precision\", on_plot=on_plot)\n plot_mc_curve(x, r_curve, save_dir / f\"{prefix}R_curve.png\", names, ylabel=\"Recall\", on_plot=on_plot)\n\n i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index\n p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values\n tp = (r * nt).round() # true positives\n fp = (tp / (p + eps) - tp).round() # false positives\n return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values", "chunk_type": "function", "name": "ap_per_class", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 674, "end_line": 768, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "Compute the average precision per class for object detection evaluation.\n\nArgs:\n tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).\n conf (np.ndarray): Array of confidence scores of the detections.\n pred_cls (np.ndarray): Array of predicted classes of the detections.\n target_cls (np.ndarray): Array of true classes of the detections.\n plot (bool, optional): Whether to plot PR curves or not.\n on_plot (callable, optional): A callback to pass plots path and data when they are rendered.\n save_dir (Path, optional): Directory to save the PR curves.\n names (Dict[int, str], optional): Dictionary of class names to plot PR curves.\n eps (float, optional): A small value to avoid division by zero.\n prefix (str, optional): A prefix string for saving the plot files.\n\nReturns:\n tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.\n fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.\n p (np.ndarray): Precision values at threshold given by max F1 metric for each class.\n r (np.ndarray): Recall values at threshold given by max F1 metric for each class.\n f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.\n ap (np.ndarray): Average precision for each class at different IoU thresholds.\n unique_classes (np.ndarray): An array of unique classes that have data.\n p_curve (np.ndarray): Precision curves for each class.\n r_curve (np.ndarray): Recall curves for each 
class.\n f1_curve (np.ndarray): F1-score curves for each class.\n x (np.ndarray): X-axis values for the curves.\n prec_values (np.ndarray): Precision values at mAP@0.5 for each class.", "parameters": [ "tp: np.ndarray", "conf: np.ndarray", "pred_cls: np.ndarray", "target_cls: np.ndarray", "plot: bool", "on_plot", "save_dir: Path", "names: Dict[int, str]", "eps: float", "prefix: str" ], "return_type": "Tuple", "decorators": [], "complexity_score": 7, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re" ], "chunk_id": "function_ap_per_class_3008df19" }, { "content": "class Metric(SimpleClass):\n \"\"\"\n Class for computing evaluation metrics for Ultralytics YOLO models.\n\n Attributes:\n p (list): Precision for each class. Shape: (nc,).\n r (list): Recall for each class. Shape: (nc,).\n f1 (list): F1 score for each class. Shape: (nc,).\n all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).\n ap_class_index (list): Index of class for each AP score. Shape: (nc,).\n nc (int): Number of classes.\n\n Methods:\n ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n mp(): Mean precision of all classes. Returns: Float.\n mr(): Mean recall of all classes. Returns: Float.\n map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.\n map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.\n map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.\n mean_results(): Mean of results, returns mp, mr, map50, map.\n class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].\n maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).\n fitness(): Model fitness as a weighted combination of metrics. 
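A compact sketch of ap_per_class on synthetic detections (a single IoU threshold only), assuming the same import path; the class names are placeholders.

import numpy as np
from ultralytics.utils.metrics import ap_per_class  # assumed import path (see file_path above)

tp = np.array([[1], [1], [0], [1], [0]], dtype=bool)   # (n_detections, n_iou_thresholds)
conf = np.array([0.9, 0.8, 0.7, 0.6, 0.3])
pred_cls = np.array([0, 0, 1, 1, 1])
target_cls = np.array([0, 0, 1, 1])

out = ap_per_class(tp, conf, pred_cls, target_cls, plot=False, names={0: "cat", 1: "dog"})
tp_c, fp_c, p, r, f1, ap = out[:6]
print(ap.shape)                                        # (2, 1): one AP column per IoU threshold
print(p, r, f1)                                        # per-class values at the max-F1 confidence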
Returns: Float.\n update(results): Update metric attributes with new evaluation results.\n \"\"\"\n\n def __init__(self) -> None:\n \"\"\"Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model.\"\"\"\n self.p = [] # (nc, )\n self.r = [] # (nc, )\n self.f1 = [] # (nc, )\n self.all_ap = [] # (nc, 10)\n self.ap_class_index = [] # (nc, )\n self.nc = 0\n\n @property\n def ap50(self) -> Union[np.ndarray, List]:\n \"\"\"\n Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.\n\n Returns:\n (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.\n \"\"\"\n return self.all_ap[:, 0] if len(self.all_ap) else []\n\n @property\n def ap(self) -> Union[np.ndarray, List]:\n \"\"\"\n Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.\n\n Returns:\n (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.\n \"\"\"\n return self.all_ap.mean(1) if len(self.all_ap) else []\n\n @property\n def mp(self) -> float:\n \"\"\"\n Return the Mean Precision of all classes.\n\n Returns:\n (float): The mean precision of all classes.\n \"\"\"\n return self.p.mean() if len(self.p) else 0.0\n\n @property\n def mr(self) -> float:\n \"\"\"\n Return the Mean Recall of all classes.\n\n Returns:\n (float): The mean recall of all classes.\n \"\"\"\n return self.r.mean() if len(self.r) else 0.0\n\n @property\n def map50(self) -> float:\n \"\"\"\n Return the mean Average Precision (mAP) at an IoU threshold of 0.5.\n\n Returns:\n (float): The mAP at an IoU threshold of 0.5.\n \"\"\"\n return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0\n\n @property\n def map75(self) -> float:\n \"\"\"\n Return the mean Average Precision (mAP) at an IoU threshold of 0.75.\n\n Returns:\n (float): The mAP at an IoU threshold of 0.75.\n \"\"\"\n return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0\n\n @property\n def map(self) -> float:\n \"\"\"\n Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.\n\n Returns:\n (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.\n \"\"\"\n return self.all_ap.mean() if len(self.all_ap) else 0.0\n\n def mean_results(self) -> List[float]:\n \"\"\"Return mean of results, mp, mr, map50, map.\"\"\"\n return [self.mp, self.mr, self.map50, self.map]\n\n def class_result(self, i: int) -> Tuple[float, float, float, float]:\n \"\"\"Return class-aware result, p[i], r[i], ap50[i], ap[i].\"\"\"\n return self.p[i], self.r[i], self.ap50[i], self.ap[i]\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return mAP of each class.\"\"\"\n maps = np.zeros(self.nc) + self.map\n for i, c in enumerate(self.ap_class_index):\n maps[c] = self.ap[i]\n return maps\n\n def fitness(self) -> float:\n \"\"\"Return model fitness as a weighted combination of metrics.\"\"\"\n w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]\n return (np.nan_to_num(np.array(self.mean_results())) * w).sum()\n\n def update(self, results: tuple):\n \"\"\"\n Update the evaluation metrics with a new set of results.\n\n Args:\n results (tuple): A tuple containing evaluation metrics:\n - p (list): Precision for each class.\n - r (list): Recall for each class.\n - f1 (list): F1 score for each class.\n - all_ap (list): AP scores for all classes and all IoU thresholds.\n - ap_class_index (list): Index of class for each AP score.\n - p_curve (list): Precision curve for each class.\n - r_curve (list): 
Recall curve for each class.\n - f1_curve (list): F1 curve for each class.\n - px (list): X values for the curves.\n - prec_values (list): Precision values for each class.\n \"\"\"\n (\n self.p,\n self.r,\n self.f1,\n self.all_ap,\n self.ap_class_index,\n self.p_curve,\n self.r_curve,\n self.f1_curve,\n self.px,\n self.prec_values,\n ) = results\n\n @property\n def curves(self) -> List:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return []\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return [\n [self.px, self.prec_values, \"Recall\", \"Precision\"],\n [self.px, self.f1_curve, \"Confidence\", \"F1\"],\n [self.px, self.p_curve, \"Confidence\", \"Precision\"],\n [self.px, self.r_curve, \"Confidence\", \"Recall\"],\n ]", "chunk_type": "class", "name": "Metric", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 771, "end_line": 941, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Class for computing evaluation metrics for Ultralytics YOLO models.\n\nAttributes:\n p (list): Precision for each class. Shape: (nc,).\n r (list): Recall for each class. Shape: (nc,).\n f1 (list): F1 score for each class. Shape: (nc,).\n all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).\n ap_class_index (list): Index of class for each AP score. Shape: (nc,).\n nc (int): Number of classes.\n\nMethods:\n ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n mp(): Mean precision of all classes. Returns: Float.\n mr(): Mean recall of all classes. Returns: Float.\n map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.\n map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.\n map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.\n mean_results(): Mean of results, returns mp, mr, map50, map.\n class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].\n maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).\n fitness(): Model fitness as a weighted combination of metrics. 
Returns: Float.\n update(results): Update metric attributes with new evaluation results.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "SimpleClass" ], "chunk_id": "class_Metric_62f28f7e" }, { "content": "class DetMetrics(SimpleClass, DataExportMixin):\n \"\"\"\n Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).\n\n Attributes:\n names (Dict[int, str]): A dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'detect'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize a DetMetrics instance with a save directory, plot flag, and class names.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n self.names = names\n self.box = Metric()\n self.speed = {\"preprocess\": 0.0, \"inference\": 0.0, \"loss\": 0.0, \"postprocess\": 0.0}\n self.task = \"detect\"\n self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])\n self.nt_per_class = None\n self.nt_per_image = None\n\n def update_stats(self, stat: Dict[str, Any]) -> None:\n \"\"\"\n Update statistics by appending new values to existing stat collections.\n\n Args:\n stat (Dict[str, any]): Dictionary containing new statistical values to append.\n Keys should match existing keys in self.stats.\n \"\"\"\n for k in self.stats.keys():\n self.stats[k].append(stat[k])\n\n def process(self, save_dir: Path = Path(\".\"), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:\n \"\"\"\n Process predicted results for object detection and update metrics.\n\n Args:\n save_dir (Path): Directory to save plots. Defaults to Path(\".\").\n plot (bool): Whether to plot precision-recall curves. Defaults to False.\n on_plot (callable, optional): Function to call after plots are generated. 
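Sketch of the Metric container being fed with ap_per_class output (everything after the tp/fp counts), assuming the same import paths; the detections are synthetic.

import numpy as np
from ultralytics.utils.metrics import Metric, ap_per_class  # assumed import paths (see file_path above)

tp = np.array([[1], [1], [0], [1]], dtype=bool)
conf = np.array([0.9, 0.8, 0.7, 0.4])
pred_cls = np.array([0, 1, 1, 0])
target_cls = np.array([0, 0, 1, 1])

m = Metric()
m.nc = 2                                                     # only needed for the maps property
m.update(ap_per_class(tp, conf, pred_cls, target_cls)[2:])   # drop tp/fp counts, keep the rest
print(m.mp, m.mr, m.map50, m.map)                            # mean precision/recall and mAP values
print(m.fitness())                                           # 0.1 * mAP50 + 0.9 * mAP50-95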
Defaults to None.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.\n \"\"\"\n stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy\n if len(stats) == 0:\n return stats\n results = ap_per_class(\n stats[\"tp\"],\n stats[\"conf\"],\n stats[\"pred_cls\"],\n stats[\"target_cls\"],\n plot=plot,\n save_dir=save_dir,\n names=self.names,\n on_plot=on_plot,\n prefix=\"Box\",\n )[2:]\n self.box.nc = len(self.names)\n self.box.update(results)\n self.nt_per_class = np.bincount(stats[\"target_cls\"].astype(int), minlength=len(self.names))\n self.nt_per_image = np.bincount(stats[\"target_img\"].astype(int), minlength=len(self.names))\n return stats\n\n def clear_stats(self):\n \"\"\"Clear the stored statistics.\"\"\"\n for v in self.stats.values():\n v.clear()\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return a list of keys for accessing specific metrics.\"\"\"\n return [\"metrics/precision(B)\", \"metrics/recall(B)\", \"metrics/mAP50(B)\", \"metrics/mAP50-95(B)\"]\n\n def mean_results(self) -> List[float]:\n \"\"\"Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.\"\"\"\n return self.box.mean_results()\n\n def class_result(self, i: int) -> Tuple[float, float, float, float]:\n \"\"\"Return the result of evaluating the performance of an object detection model on a specific class.\"\"\"\n return self.box.class_result(i)\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return mean Average Precision (mAP) scores per class.\"\"\"\n return self.box.maps\n\n @property\n def fitness(self) -> float:\n \"\"\"Return the fitness of box object.\"\"\"\n return self.box.fitness()\n\n @property\n def ap_class_index(self) -> List:\n \"\"\"Return the average precision index per class.\"\"\"\n return self.box.ap_class_index\n\n @property\n def results_dict(self) -> Dict[str, float]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return dict(zip(self.keys + [\"fitness\"], self.mean_results() + [self.fitness]))\n\n @property\n def curves(self) -> List[str]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return [\"Precision-Recall(B)\", \"F1-Confidence(B)\", \"Precision-Confidence(B)\", \"Recall-Confidence(B)\"]\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return self.box.curves_results\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Generate a summarized representation of per-class detection metrics as a list of dictionaries. 
Includes shared\n scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.\n\n Args:\n normalize (bool): For Detect metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.\n\n Examples:\n >>> results = model.val(data=\"coco8.yaml\")\n >>> detection_summary = results.summary()\n >>> print(detection_summary)\n \"\"\"\n per_class = {\n \"Box-P\": self.box.p,\n \"Box-R\": self.box.r,\n \"Box-F1\": self.box.f1,\n }\n return [\n {\n \"Class\": self.names[self.ap_class_index[i]],\n \"Images\": self.nt_per_image[self.ap_class_index[i]],\n \"Instances\": self.nt_per_class[self.ap_class_index[i]],\n **{k: round(v[i], decimals) for k, v in per_class.items()},\n \"mAP50\": round(self.class_result(i)[2], decimals),\n \"mAP50-95\": round(self.class_result(i)[3], decimals),\n }\n for i in range(len(per_class[\"Box-P\"]))\n ]", "chunk_type": "class", "name": "DetMetrics", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 944, "end_line": 1096, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).\n\nAttributes:\n names (Dict[int, str]): A dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'detect'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "SimpleClass", "DataExportMixin" ], "chunk_id": "class_DetMetrics_d7bfdbda" }, { "content": "class SegmentMetrics(DetMetrics):\n \"\"\"\n Calculate and aggregate detection and segmentation metrics over a given set of classes.\n\n Attributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'segment'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n DetMetrics.__init__(self, 
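Finally, a sketch of the higher-level DetMetrics flow (accumulate stats, then process), assuming the same import path; the single synthetic batch below is invented.

import numpy as np
from ultralytics.utils.metrics import DetMetrics  # assumed import path (see file_path above)

dm = DetMetrics(names={0: "cat", 1: "dog"})
dm.update_stats(
    {
        "tp": np.array([[1], [0], [1]], dtype=bool),  # (n_detections, n_iou_thresholds)
        "conf": np.array([0.9, 0.6, 0.4]),
        "pred_cls": np.array([0, 0, 1]),
        "target_cls": np.array([0, 1]),               # ground-truth classes in the batch
        "target_img": np.array([0, 1]),               # classes present per image
    }
)
dm.process(plot=False)            # runs ap_per_class under the hood, no plots written
print(dm.results_dict)            # precision, recall, mAP50, mAP50-95 and fitness
print(dm.summary())               # per-class rows ready for export via DataExportMixin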
names)\n self.seg = Metric()\n self.task = \"segment\"\n self.stats[\"tp_m\"] = [] # add additional stats for masks\n\n def process(self, save_dir: Path = Path(\".\"), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:\n \"\"\"\n Process the detection and segmentation metrics over the given set of predictions.\n\n Args:\n save_dir (Path): Directory to save plots. Defaults to Path(\".\").\n plot (bool): Whether to plot precision-recall curves. Defaults to False.\n on_plot (callable, optional): Function to call after plots are generated. Defaults to None.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.\n \"\"\"\n stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats\n results_mask = ap_per_class(\n stats[\"tp_m\"],\n stats[\"conf\"],\n stats[\"pred_cls\"],\n stats[\"target_cls\"],\n plot=plot,\n on_plot=on_plot,\n save_dir=save_dir,\n names=self.names,\n prefix=\"Mask\",\n )[2:]\n self.seg.nc = len(self.names)\n self.seg.update(results_mask)\n return stats\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return a list of keys for accessing metrics.\"\"\"\n return DetMetrics.keys.fget(self) + [\n \"metrics/precision(M)\",\n \"metrics/recall(M)\",\n \"metrics/mAP50(M)\",\n \"metrics/mAP50-95(M)\",\n ]\n\n def mean_results(self) -> List[float]:\n \"\"\"Return the mean metrics for bounding box and segmentation results.\"\"\"\n return DetMetrics.mean_results(self) + self.seg.mean_results()\n\n def class_result(self, i: int) -> List[float]:\n \"\"\"Return classification results for a specified class index.\"\"\"\n return DetMetrics.class_result(self, i) + self.seg.class_result(i)\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return mAP scores for object detection and semantic segmentation models.\"\"\"\n return DetMetrics.maps.fget(self) + self.seg.maps\n\n @property\n def fitness(self) -> float:\n \"\"\"Return the fitness score for both segmentation and bounding box models.\"\"\"\n return self.seg.fitness() + DetMetrics.fitness.fget(self)\n\n @property\n def curves(self) -> List[str]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return DetMetrics.curves.fget(self) + [\n \"Precision-Recall(M)\",\n \"F1-Confidence(M)\",\n \"Precision-Confidence(M)\",\n \"Recall-Confidence(M)\",\n ]\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return DetMetrics.curves_results.fget(self) + self.seg.curves_results\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. 
Includes both\n box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.\n\n Args:\n normalize (bool): For Segment metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.\n\n Examples:\n >>> results = model.val(data=\"coco8-seg.yaml\")\n >>> seg_summary = results.summary(decimals=4)\n >>> print(seg_summary)\n \"\"\"\n per_class = {\n \"Mask-P\": self.seg.p,\n \"Mask-R\": self.seg.r,\n \"Mask-F1\": self.seg.f1,\n }\n summary = DetMetrics.summary(self, normalize, decimals) # get box summary\n for i, s in enumerate(summary):\n s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})\n return summary", "chunk_type": "class", "name": "SegmentMetrics", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 1099, "end_line": 1222, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Calculate and aggregate detection and segmentation metrics over a given set of classes.\n\nAttributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'segment'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "DetMetrics" ], "chunk_id": "class_SegmentMetrics_97d3182e" }, { "content": "class PoseMetrics(DetMetrics):\n \"\"\"\n Calculate and aggregate detection and pose metrics over a given set of classes.\n\n Attributes:\n names (Dict[int, str]): Dictionary of class names.\n pose (Metric): An instance of the Metric class to calculate pose metrics.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'pose'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\n Methods:\n process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.\n mean_results(): Return the mean of the detection and segmentation metrics over all the classes.\n class_result(i): Return the detection and segmentation metrics of class `i`.\n maps: Return the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.\n fitness: Return the fitness scores, which 
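The `summary()` methods recorded above all follow the same pattern: build one dictionary per class from parallel per-class metric arrays, then merge in extra columns for the task-specific head (masks here, keypoints for pose). A minimal standalone sketch of that dict-building pattern, using invented numpy arrays in place of `self.box.p` / `self.seg.p` and a plain class index instead of `ap_class_index`:

```python
import numpy as np

# Hypothetical per-class arrays standing in for Metric.p / .r / .f1 (3 classes).
names = {0: "person", 1: "car", 2: "dog"}
box = {"Box-P": np.array([0.81, 0.74, 0.66]),
       "Box-R": np.array([0.77, 0.70, 0.61]),
       "Box-F1": np.array([0.79, 0.72, 0.63])}
mask = {"Mask-P": np.array([0.78, 0.70, 0.60]),
        "Mask-R": np.array([0.74, 0.66, 0.58]),
        "Mask-F1": np.array([0.76, 0.68, 0.59])}
decimals = 5

# Box summary: one dict per class, mirroring what DetMetrics.summary() builds.
summary = [
    {"Class": names[i], **{k: round(float(v[i]), decimals) for k, v in box.items()}}
    for i in range(len(names))
]

# Segment summary: update each row with mask columns, as SegmentMetrics.summary() does.
for i, row in enumerate(summary):
    row.update({k: round(float(v[i]), decimals) for k, v in mask.items()})

print(summary[0])  # {'Class': 'person', 'Box-P': 0.81, ..., 'Mask-F1': 0.76}
```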
are a single weighted combination of metrics.\n ap_class_index: Return the list of indices of classes used to compute Average Precision (AP).\n results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize the PoseMetrics class with directory path, class names, and plotting options.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n super().__init__(names)\n self.pose = Metric()\n self.task = \"pose\"\n self.stats[\"tp_p\"] = [] # add additional stats for pose\n\n def process(self, save_dir: Path = Path(\".\"), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:\n \"\"\"\n Process the detection and pose metrics over the given set of predictions.\n\n Args:\n save_dir (Path): Directory to save plots. Defaults to Path(\".\").\n plot (bool): Whether to plot precision-recall curves. Defaults to False.\n on_plot (callable, optional): Function to call after plots are generated.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.\n \"\"\"\n stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats\n results_pose = ap_per_class(\n stats[\"tp_p\"],\n stats[\"conf\"],\n stats[\"pred_cls\"],\n stats[\"target_cls\"],\n plot=plot,\n on_plot=on_plot,\n save_dir=save_dir,\n names=self.names,\n prefix=\"Pose\",\n )[2:]\n self.pose.nc = len(self.names)\n self.pose.update(results_pose)\n return stats\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return list of evaluation metric keys.\"\"\"\n return DetMetrics.keys.fget(self) + [\n \"metrics/precision(P)\",\n \"metrics/recall(P)\",\n \"metrics/mAP50(P)\",\n \"metrics/mAP50-95(P)\",\n ]\n\n def mean_results(self) -> List[float]:\n \"\"\"Return the mean results of box and pose.\"\"\"\n return DetMetrics.mean_results(self) + self.pose.mean_results()\n\n def class_result(self, i: int) -> List[float]:\n \"\"\"Return the class-wise detection results for a specific class i.\"\"\"\n return DetMetrics.class_result(self, i) + self.pose.class_result(i)\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return the mean average precision (mAP) per class for both box and pose detections.\"\"\"\n return DetMetrics.maps.fget(self) + self.pose.maps\n\n @property\n def fitness(self) -> float:\n \"\"\"Return combined fitness score for pose and box detection.\"\"\"\n return self.pose.fitness() + DetMetrics.fitness.fget(self)\n\n @property\n def curves(self) -> List[str]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return DetMetrics.curves.fget(self) + [\n \"Precision-Recall(B)\",\n \"F1-Confidence(B)\",\n \"Precision-Confidence(B)\",\n \"Recall-Confidence(B)\",\n \"Precision-Recall(P)\",\n \"F1-Confidence(P)\",\n \"Precision-Confidence(P)\",\n \"Recall-Confidence(P)\",\n ]\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return DetMetrics.curves_results.fget(self) + self.pose.curves_results\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Generate a summarized representation of per-class pose metrics as a list of dictionaries. 
Includes both box and\n pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.\n\n Args:\n normalize (bool): For Pose metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.\n\n Examples:\n >>> results = model.val(data=\"coco8-pose.yaml\")\n >>> pose_summary = results.summary(decimals=4)\n >>> print(pose_summary)\n \"\"\"\n per_class = {\n \"Pose-P\": self.pose.p,\n \"Pose-R\": self.pose.r,\n \"Pose-F1\": self.pose.f1,\n }\n summary = DetMetrics.summary(self, normalize, decimals) # get box summary\n for i, s in enumerate(summary):\n s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})\n return summary", "chunk_type": "class", "name": "PoseMetrics", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 1225, "end_line": 1361, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Calculate and aggregate detection and pose metrics over a given set of classes.\n\nAttributes:\n names (Dict[int, str]): Dictionary of class names.\n pose (Metric): An instance of the Metric class to calculate pose metrics.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'pose'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\nMethods:\n process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.\n mean_results(): Return the mean of the detection and segmentation metrics over all the classes.\n class_result(i): Return the detection and segmentation metrics of class `i`.\n maps: Return the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.\n fitness: Return the fitness scores, which are a single weighted combination of metrics.\n ap_class_index: Return the list of indices of classes used to compute Average Precision (AP).\n results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "DetMetrics" ], "chunk_id": "class_PoseMetrics_25476aa9" }, { "content": "class ClassifyMetrics(SimpleClass, DataExportMixin):\n \"\"\"\n Class for computing classification metrics including top-1 and top-5 accuracy.\n\n Attributes:\n top1 (float): The top-1 accuracy.\n top5 (float): The top-5 accuracy.\n speed (dict): A dictionary containing the time taken for each step in the pipeline.\n task (str): The task type, set to 'classify'.\n \"\"\"\n\n def __init__(self) -> None:\n \"\"\"Initialize a ClassifyMetrics instance.\"\"\"\n self.top1 = 0\n self.top5 = 0\n self.speed = 
{\"preprocess\": 0.0, \"inference\": 0.0, \"loss\": 0.0, \"postprocess\": 0.0}\n self.task = \"classify\"\n\n def process(self, targets: torch.Tensor, pred: torch.Tensor):\n \"\"\"\n Process target classes and predicted classes to compute metrics.\n\n Args:\n targets (torch.Tensor): Target classes.\n pred (torch.Tensor): Predicted classes.\n \"\"\"\n pred, targets = torch.cat(pred), torch.cat(targets)\n correct = (targets[:, None] == pred).float()\n acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy\n self.top1, self.top5 = acc.mean(0).tolist()\n\n @property\n def fitness(self) -> float:\n \"\"\"Return mean of top-1 and top-5 accuracies as fitness score.\"\"\"\n return (self.top1 + self.top5) / 2\n\n @property\n def results_dict(self) -> Dict[str, float]:\n \"\"\"Return a dictionary with model's performance metrics and fitness score.\"\"\"\n return dict(zip(self.keys + [\"fitness\"], [self.top1, self.top5, self.fitness]))\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return a list of keys for the results_dict property.\"\"\"\n return [\"metrics/accuracy_top1\", \"metrics/accuracy_top5\"]\n\n @property\n def curves(self) -> List:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return []\n\n @property\n def curves_results(self) -> List:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return []\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, float]]:\n \"\"\"\n Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).\n\n Args:\n normalize (bool): For Classify metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, float]]): A list with one dictionary containing Top-1 and Top-5 classification accuracy.\n\n Examples:\n >>> results = model.val(data=\"imagenet10\")\n >>> classify_summary = results.summary(decimals=4)\n >>> print(classify_summary)\n \"\"\"\n return [{\"top1_acc\": round(self.top1, decimals), \"top5_acc\": round(self.top5, decimals)}]", "chunk_type": "class", "name": "ClassifyMetrics", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 1364, "end_line": 1436, "start_col": 0, "end_col": 97, "parent_name": null, "docstring": "Class for computing classification metrics including top-1 and top-5 accuracy.\n\nAttributes:\n top1 (float): The top-1 accuracy.\n top5 (float): The top-5 accuracy.\n speed (dict): A dictionary containing the time taken for each step in the pipeline.\n task (str): The task type, set to 'classify'.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "SimpleClass", "DataExportMixin" ], "chunk_id": "class_ClassifyMetrics_2b5ec4ae" }, { "content": "class OBBMetrics(DetMetrics):\n \"\"\"\n Metrics for evaluating oriented bounding box (OBB) detection.\n\n Attributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for 
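`ClassifyMetrics.process` above derives top-1 and top-5 accuracy by comparing each target label against the five highest-scoring predicted classes. A standalone sketch of exactly that comparison on toy tensors (the label values are invented for illustration):

```python
import torch

# Per-batch lists, as a validator would pass them: targets are class indices,
# preds hold the top-5 predicted class indices per sample (highest score first).
targets = [torch.tensor([3, 1]), torch.tensor([7])]
preds = [torch.tensor([[3, 2, 9, 1, 0], [5, 1, 4, 8, 2]]), torch.tensor([[2, 7, 6, 0, 3]])]

pred, target = torch.cat(preds), torch.cat(targets)
correct = (target[:, None] == pred).float()              # (N, 5) match matrix
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)
top1, top5 = acc.mean(0).tolist()
print(f"top1={top1:.3f} top5={top5:.3f}")                 # top1=0.333 top5=1.000
```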
storing execution times of different parts of the detection process.\n task (str): The task type, set to 'obb'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\n References:\n https://arxiv.org/pdf/2106.06072.pdf\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize an OBBMetrics instance with directory, plotting, and class names.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n DetMetrics.__init__(self, names)\n # TODO: probably remove task as well\n self.task = \"obb\"", "chunk_type": "class", "name": "OBBMetrics", "file_path": "ultralytics\\ultralytics\\utils\\metrics.py", "start_line": 1439, "end_line": 1465, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Metrics for evaluating oriented bounding box (OBB) detection.\n\nAttributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'obb'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\nReferences:\n https://arxiv.org/pdf/2106.06072.pdf", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.utils.LOGGER", "ultralytics.utils.DataExportMixin", "ultralytics.utils.SimpleClass", "ultralytics.utils.TryExcept", "ultralytics.utils.checks", "ultralytics.utils.plt_settings", "matplotlib.pyplot", "matplotlib.pyplot", "matplotlib.pyplot", "re", "DetMetrics" ], "chunk_id": "class_OBBMetrics_cccc183b" }, { "content": "import contextlib", "chunk_type": "import", "name": "contextlib", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextlib_7046437c" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_6b55b476" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_30a394c5" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, 
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_55bf7bb5" }, { "content": "from typing import Optional", "chunk_type": "import", "name": "Optional", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Optional_fac61e14" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_0fea62e2" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_3ae73d6e" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_790c7dfa" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_4e111fa3" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_131edebf" }, { "content": "from ultralytics.utils.metrics import batch_probiou", "chunk_type": "import", "name": "batch_probiou", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_batch_probiou_5553b42e" }, { "content": "class Profile(contextlib.ContextDecorator):\n \"\"\"\n Ultralytics Profile class for timing code execution.\n\n Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing\n measurements with CUDA synchronization support for GPU operations.\n\n Attributes:\n t (float): Accumulated time in seconds.\n device (torch.device): Device used for model inference.\n cuda (bool): Whether CUDA is being used for timing synchronization.\n\n Examples:\n Use as a context manager to time code execution\n >>> with Profile(device=device) as dt:\n ... 
pass # slow operation here\n >>> print(dt) # prints \"Elapsed time is 9.5367431640625e-07 s\"\n\n Use as a decorator to time function execution\n >>> @Profile()\n ... def slow_function():\n ... time.sleep(0.1)\n \"\"\"\n\n def __init__(self, t: float = 0.0, device: Optional[torch.device] = None):\n \"\"\"\n Initialize the Profile class.\n\n Args:\n t (float): Initial accumulated time in seconds.\n device (torch.device, optional): Device used for model inference to enable CUDA synchronization.\n \"\"\"\n self.t = t\n self.device = device\n self.cuda = bool(device and str(device).startswith(\"cuda\"))\n\n def __enter__(self):\n \"\"\"Start timing.\"\"\"\n self.start = self.time()\n return self\n\n def __exit__(self, type, value, traceback): # noqa\n \"\"\"Stop timing.\"\"\"\n self.dt = self.time() - self.start # delta-time\n self.t += self.dt # accumulate dt\n\n def __str__(self):\n \"\"\"Return a human-readable string representing the accumulated elapsed time.\"\"\"\n return f\"Elapsed time is {self.t} s\"\n\n def time(self):\n \"\"\"Get current time with CUDA synchronization if applicable.\"\"\"\n if self.cuda:\n torch.cuda.synchronize(self.device)\n return time.perf_counter()", "chunk_type": "class", "name": "Profile", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 18, "end_line": 72, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Ultralytics Profile class for timing code execution.\n\nUse as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing\nmeasurements with CUDA synchronization support for GPU operations.\n\nAttributes:\n t (float): Accumulated time in seconds.\n device (torch.device): Device used for model inference.\n cuda (bool): Whether CUDA is being used for timing synchronization.\n\nExamples:\n Use as a context manager to time code execution\n >>> with Profile(device=device) as dt:\n ... pass # slow operation here\n >>> print(dt) # prints \"Elapsed time is 9.5367431640625e-07 s\"\n\n Use as a decorator to time function execution\n >>> @Profile()\n ... def slow_function():\n ... 
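A short usage sketch for the `Profile` class shown above, used both as a context manager and as a decorator. It assumes the `ultralytics` package is installed and importable; the sleep calls are stand-ins for real work and the printed timings are illustrative only.

```python
import time
from ultralytics.utils.ops import Profile  # the class defined above

# As a context manager: dt.dt holds the last delta, dt.t accumulates across uses.
dt = Profile()
for _ in range(3):
    with dt:
        time.sleep(0.01)                    # stand-in for preprocessing / inference work
print(f"last={dt.dt:.4f}s, {dt}")           # __str__ reports the accumulated total

# As a decorator (contextlib.ContextDecorator), timing every call of the function.
@Profile()
def slow_function():
    time.sleep(0.05)

slow_function()
```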
time.sleep(0.1)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment", "contextlib.ContextDecorator" ], "chunk_id": "class_Profile_6b5e8a32" }, { "content": "def segment2box(segment, width: int = 640, height: int = 640):\n \"\"\"\n Convert segment coordinates to bounding box coordinates.\n\n Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.\n Applies inside-image constraint and clips coordinates when necessary.\n\n Args:\n segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.\n width (int): Width of the image in pixels.\n height (int): Height of the image in pixels.\n\n Returns:\n (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].\n \"\"\"\n x, y = segment.T # segment xy\n # Clip coordinates if 3 out of 4 sides are outside the image\n if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:\n x = x.clip(0, width)\n y = y.clip(0, height)\n inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)\n x = x[inside]\n y = y[inside]\n return (\n np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)\n if any(x)\n else np.zeros(4, dtype=segment.dtype)\n ) # xyxy", "chunk_type": "function", "name": "segment2box", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 75, "end_line": 102, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Convert segment coordinates to bounding box coordinates.\n\nConverts a single segment label to a box label by finding the minimum and maximum x and y coordinates.\nApplies inside-image constraint and clips coordinates when necessary.\n\nArgs:\n segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.\n width (int): Width of the image in pixels.\n height (int): Height of the image in pixels.\n\nReturns:\n (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].", "parameters": [ "segment", "width: int", "height: int" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_segment2box_3f4ab63a" }, { "content": "def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):\n \"\"\"\n Rescale bounding boxes from one image shape to another.\n\n Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.\n Supports both xyxy and xywh box formats.\n\n Args:\n img1_shape (tuple): Shape of the source image (height, width).\n boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).\n img0_shape (tuple): Shape of the target image (height, width).\n ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. 
If None, calculated from image shapes.\n padding (bool): Whether boxes are based on YOLO-style augmented images with padding.\n xywh (bool): Whether box format is xywh (True) or xyxy (False).\n\n Returns:\n (torch.Tensor): Rescaled bounding boxes in the same format as input.\n \"\"\"\n if ratio_pad is None: # calculate from img0_shape\n gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new\n pad = (\n round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),\n round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),\n ) # wh padding\n else:\n gain = ratio_pad[0][0]\n pad = ratio_pad[1]\n\n if padding:\n boxes[..., 0] -= pad[0] # x padding\n boxes[..., 1] -= pad[1] # y padding\n if not xywh:\n boxes[..., 2] -= pad[0] # x padding\n boxes[..., 3] -= pad[1] # y padding\n boxes[..., :4] /= gain\n return clip_boxes(boxes, img0_shape)", "chunk_type": "function", "name": "scale_boxes", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 105, "end_line": 140, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Rescale bounding boxes from one image shape to another.\n\nRescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.\nSupports both xyxy and xywh box formats.\n\nArgs:\n img1_shape (tuple): Shape of the source image (height, width).\n boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).\n img0_shape (tuple): Shape of the target image (height, width).\n ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.\n padding (bool): Whether boxes are based on YOLO-style augmented images with padding.\n xywh (bool): Whether box format is xywh (True) or xyxy (False).\n\nReturns:\n (torch.Tensor): Rescaled bounding boxes in the same format as input.", "parameters": [ "img1_shape", "boxes", "img0_shape", "ratio_pad", "padding: bool", "xywh: bool" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_scale_boxes_e90f89fd" }, { "content": "def make_divisible(x: int, divisor):\n \"\"\"\n Return the nearest number that is divisible by the given divisor.\n\n Args:\n x (int): The number to make divisible.\n divisor (int | torch.Tensor): The divisor.\n\n Returns:\n (int): The nearest number divisible by the divisor.\n \"\"\"\n if isinstance(divisor, torch.Tensor):\n divisor = int(divisor.max()) # to int\n return math.ceil(x / divisor) * divisor", "chunk_type": "function", "name": "make_divisible", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 143, "end_line": 156, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Return the nearest number that is divisible by the given divisor.\n\nArgs:\n x (int): The number to make divisible.\n divisor (int | torch.Tensor): The divisor.\n\nReturns:\n (int): The nearest number divisible by the divisor.", "parameters": [ "x: int", "divisor" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": 
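`scale_boxes` above undoes the letterbox transform: it recomputes the resize gain and the symmetric padding from the two image shapes, subtracts the padding, divides by the gain, and clips. A standalone numeric walk-through of that arithmetic (the shapes and the box are made up):

```python
# Letterboxed inference at 640x640 from an original 480x640 (h, w) image.
img1_shape = (640, 640)   # network input (h, w)
img0_shape = (480, 640)   # original image (h, w)

gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # 1.0
pad_w = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)           # 0
pad_h = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)           # 80

# A box predicted in letterboxed coordinates (x1, y1, x2, y2).
x1, y1, x2, y2 = 100.0, 120.0, 300.0, 400.0
x1, x2 = (x1 - pad_w) / gain, (x2 - pad_w) / gain
y1, y2 = (y1 - pad_h) / gain, (y2 - pad_h) / gain

# Clip to the original image, as clip_boxes() does.
x1, x2 = max(0, min(x1, img0_shape[1])), max(0, min(x2, img0_shape[1]))
y1, y2 = max(0, min(y1, img0_shape[0])), max(0, min(y2, img0_shape[0]))
print(x1, y1, x2, y2)  # 100.0 40.0 300.0 320.0
```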
"function_make_divisible_3cc87115" }, { "content": "def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True):\n \"\"\"\n Perform NMS on oriented bounding boxes using probiou and fast-nms.\n\n Args:\n boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.\n scores (torch.Tensor): Confidence scores with shape (N,).\n threshold (float): IoU threshold for NMS.\n use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.\n\n Returns:\n (torch.Tensor): Indices of boxes to keep after NMS.\n \"\"\"\n sorted_idx = torch.argsort(scores, descending=True)\n boxes = boxes[sorted_idx]\n ious = batch_probiou(boxes, boxes)\n if use_triu:\n ious = ious.triu_(diagonal=1)\n # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition\n pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)\n else:\n n = boxes.shape[0]\n row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)\n col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)\n upper_mask = row_idx < col_idx\n ious = ious * upper_mask\n # Zeroing these scores ensures the additional indices would not affect the final results\n scores[~((ious >= threshold).sum(0) <= 0)] = 0\n # NOTE: return indices with fixed length to avoid TFLite reshape error\n pick = torch.topk(scores, scores.shape[0]).indices\n return sorted_idx[pick]", "chunk_type": "function", "name": "nms_rotated", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 159, "end_line": 189, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Perform NMS on oriented bounding boxes using probiou and fast-nms.\n\nArgs:\n boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.\n scores (torch.Tensor): Confidence scores with shape (N,).\n threshold (float): IoU threshold for NMS.\n use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.\n\nReturns:\n (torch.Tensor): Indices of boxes to keep after NMS.", "parameters": [ "boxes", "scores", "threshold: float", "use_triu: bool" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_nms_rotated_4d120576" }, { "content": "def non_max_suppression(\n prediction,\n conf_thres: float = 0.25,\n iou_thres: float = 0.45,\n classes=None,\n agnostic: bool = False,\n multi_label: bool = False,\n labels=(),\n max_det: int = 300,\n nc: int = 0, # number of classes (optional)\n max_time_img: float = 0.05,\n max_nms: int = 30000,\n max_wh: int = 7680,\n in_place: bool = True,\n rotated: bool = False,\n end2end: bool = False,\n return_idxs: bool = False,\n):\n \"\"\"\n Perform non-maximum suppression (NMS) on prediction results.\n\n Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple\n detection formats including standard boxes, rotated boxes, and masks.\n\n Args:\n prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)\n containing boxes, classes, and optional masks.\n conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.\n iou_thres (float): IoU threshold for NMS filtering. 
Valid values are between 0.0 and 1.0.\n classes (List[int], optional): List of class indices to consider. If None, all classes are considered.\n agnostic (bool): Whether to perform class-agnostic NMS.\n multi_label (bool): Whether each box can have multiple labels.\n labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.\n max_det (int): Maximum number of detections to keep per image.\n nc (int): Number of classes. Indices after this are considered masks.\n max_time_img (float): Maximum time in seconds for processing one image.\n max_nms (int): Maximum number of boxes for torchvision.ops.nms().\n max_wh (int): Maximum box width and height in pixels.\n in_place (bool): Whether to modify the input prediction tensor in place.\n rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).\n end2end (bool): Whether the model is end-to-end and doesn't require NMS.\n return_idxs (bool): Whether to return the indices of kept detections.\n\n Returns:\n output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)\n containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).\n keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n # Checks\n assert 0 <= conf_thres <= 1, f\"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0\"\n assert 0 <= iou_thres <= 1, f\"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0\"\n if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)\n prediction = prediction[0] # select only inference output\n if classes is not None:\n classes = torch.tensor(classes, device=prediction.device)\n\n if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)\n output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]\n if classes is not None:\n output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]\n return output\n\n bs = prediction.shape[0] # batch size (BCN, i.e. 
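`nms_rotated` above uses the fast-NMS formulation: with boxes already sorted by score, a box is kept only if no higher-scoring box overlaps it above the threshold, which reduces to a column check on the upper triangle of the pairwise IoU matrix. A standalone sketch of that selection rule, where a hand-written IoU matrix stands in for `batch_probiou` purely for illustration:

```python
import torch

iou_thres = 0.45
# Pairwise IoUs for 4 boxes already sorted by descending score (values invented).
ious = torch.tensor([
    [1.00, 0.60, 0.10, 0.05],
    [0.60, 1.00, 0.20, 0.05],
    [0.10, 0.20, 1.00, 0.50],
    [0.05, 0.05, 0.50, 1.00],
])

# Keep only the upper triangle: entry (i, j) with i < j means "box i suppresses box j".
upper = ious.triu_(diagonal=1)
keep = torch.nonzero((upper >= iou_thres).sum(0) <= 0).squeeze(-1)
print(keep.tolist())  # [0, 2] -> box 1 is suppressed by 0, box 3 by 2
```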
1,84,6300)\n nc = nc or (prediction.shape[1] - 4) # number of classes\n extra = prediction.shape[1] - nc - 4 # number of extra info\n mi = 4 + nc # mask start index\n xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates\n xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None] # to track idxs\n\n # Settings\n # min_wh = 2 # (pixels) minimum box width and height\n time_limit = 2.0 + max_time_img * bs # seconds to quit after\n multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)\n\n prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)\n if not rotated:\n if in_place:\n prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy\n else:\n prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy\n\n t = time.time()\n output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs\n keepi = [torch.zeros((0, 1), device=prediction.device)] * bs # to store the kept idxs\n for xi, (x, xk) in enumerate(zip(prediction, xinds)): # image index, (preds, preds indices)\n # Apply constraints\n # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height\n filt = xc[xi] # confidence\n x, xk = x[filt], xk[filt]\n\n # Cat apriori labels if autolabelling\n if labels and len(labels[xi]) and not rotated:\n lb = labels[xi]\n v = torch.zeros((len(lb), nc + extra + 4), device=x.device)\n v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box\n v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls\n x = torch.cat((x, v), 0)\n\n # If none remain process next image\n if not x.shape[0]:\n continue\n\n # Detections matrix nx6 (xyxy, conf, cls)\n box, cls, mask = x.split((4, nc, extra), 1)\n\n if multi_label:\n i, j = torch.where(cls > conf_thres)\n x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)\n xk = xk[i]\n else: # best class only\n conf, j = cls.max(1, keepdim=True)\n filt = conf.view(-1) > conf_thres\n x = torch.cat((box, conf, j.float(), mask), 1)[filt]\n xk = xk[filt]\n\n # Filter by class\n if classes is not None:\n filt = (x[:, 5:6] == classes).any(1)\n x, xk = x[filt], xk[filt]\n\n # Check shape\n n = x.shape[0] # number of boxes\n if not n: # no boxes\n continue\n if n > max_nms: # excess boxes\n filt = x[:, 4].argsort(descending=True)[:max_nms] # sort by confidence and remove excess boxes\n x, xk = x[filt], xk[filt]\n\n # Batched NMS\n c = x[:, 5:6] * (0 if agnostic else max_wh) # classes\n scores = x[:, 4] # scores\n if rotated:\n boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr\n i = nms_rotated(boxes, scores, iou_thres)\n else:\n boxes = x[:, :4] + c # boxes (offset by class)\n i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS\n i = i[:max_det] # limit detections\n\n output[xi], keepi[xi] = x[i], xk[i].reshape(-1)\n if (time.time() - t) > time_limit:\n LOGGER.warning(f\"NMS time limit {time_limit:.3f}s exceeded\")\n break # time limit exceeded\n\n return (output, keepi) if return_idxs else output", "chunk_type": "function", "name": "non_max_suppression", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 192, "end_line": 338, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": "Perform non-maximum suppression (NMS) on prediction results.\n\nApplies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. 
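`non_max_suppression` above runs a single `torchvision.ops.nms` call per image by offsetting each box by `class_index * max_wh`, so boxes of different classes can never overlap and are therefore never suppressed against each other (unless `agnostic=True`). A standalone sketch of just that offset trick, assuming torchvision is available; boxes, scores, and classes are invented:

```python
import torch
import torchvision

max_wh = 7680
iou_thres = 0.45

# Three xyxy boxes: the first two overlap heavily but belong to different classes.
boxes = torch.tensor([[10., 10., 100., 100.],
                      [12., 12., 102., 102.],
                      [11., 11., 101., 101.]])
scores = torch.tensor([0.9, 0.8, 0.7])
cls = torch.tensor([0., 1., 0.])

agnostic = False
c = cls[:, None] * (0 if agnostic else max_wh)       # per-class coordinate offset
keep = torchvision.ops.nms(boxes + c, scores, iou_thres)
print(keep.tolist())  # [0, 1] -> box 2 (same class as box 0) is suppressed
```

With `agnostic=True` the offset vanishes and box 1 would also be suppressed by box 0, which is the class-agnostic behaviour described in the docstring.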
Supports multiple\ndetection formats including standard boxes, rotated boxes, and masks.\n\nArgs:\n prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)\n containing boxes, classes, and optional masks.\n conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.\n iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.\n classes (List[int], optional): List of class indices to consider. If None, all classes are considered.\n agnostic (bool): Whether to perform class-agnostic NMS.\n multi_label (bool): Whether each box can have multiple labels.\n labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.\n max_det (int): Maximum number of detections to keep per image.\n nc (int): Number of classes. Indices after this are considered masks.\n max_time_img (float): Maximum time in seconds for processing one image.\n max_nms (int): Maximum number of boxes for torchvision.ops.nms().\n max_wh (int): Maximum box width and height in pixels.\n in_place (bool): Whether to modify the input prediction tensor in place.\n rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).\n end2end (bool): Whether the model is end-to-end and doesn't require NMS.\n return_idxs (bool): Whether to return the indices of kept detections.\n\nReturns:\n output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)\n containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).\n keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.", "parameters": [ "prediction", "conf_thres: float", "iou_thres: float", "classes", "agnostic: bool", "multi_label: bool", "labels", "max_det: int", "nc: int", "max_time_img: float", "max_nms: int", "max_wh: int", "in_place: bool", "rotated: bool", "end2end: bool", "return_idxs: bool" ], "return_type": null, "decorators": [], "complexity_score": 19, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_non_max_suppression_2bb4cfbb" }, { "content": "def clip_boxes(boxes, shape):\n \"\"\"\n Clip bounding boxes to image boundaries.\n\n Args:\n boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.\n shape (tuple): Image shape as (height, width).\n\n Returns:\n (torch.Tensor | np.ndarray): Clipped bounding boxes.\n \"\"\"\n if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)\n boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1\n boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1\n boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2\n boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2\n else: # np.array (faster grouped)\n boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2\n boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2\n return boxes", "chunk_type": "function", "name": "clip_boxes", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 341, "end_line": 360, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Clip bounding boxes to image boundaries.\n\nArgs:\n boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.\n shape (tuple): Image shape as (height, width).\n\nReturns:\n (torch.Tensor | np.ndarray): Clipped bounding 
boxes.", "parameters": [ "boxes", "shape" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_clip_boxes_7ca129ed" }, { "content": "def clip_coords(coords, shape):\n \"\"\"\n Clip line coordinates to image boundaries.\n\n Args:\n coords (torch.Tensor | np.ndarray): Line coordinates to clip.\n shape (tuple): Image shape as (height, width).\n\n Returns:\n (torch.Tensor | np.ndarray): Clipped coordinates.\n \"\"\"\n if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)\n coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x\n coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y\n else: # np.array (faster grouped)\n coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x\n coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y\n return coords", "chunk_type": "function", "name": "clip_coords", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 363, "end_line": 380, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Clip line coordinates to image boundaries.\n\nArgs:\n coords (torch.Tensor | np.ndarray): Line coordinates to clip.\n shape (tuple): Image shape as (height, width).\n\nReturns:\n (torch.Tensor | np.ndarray): Clipped coordinates.", "parameters": [ "coords", "shape" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_clip_coords_a9c40ad7" }, { "content": "def scale_image(masks, im0_shape, ratio_pad=None):\n \"\"\"\n Rescale masks to original image size.\n\n Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding\n that was applied during preprocessing.\n\n Args:\n masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].\n im0_shape (tuple): Original image shape as (height, width).\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n\n Returns:\n (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.\n \"\"\"\n # Rescale coordinates (xyxy) from im1_shape to im0_shape\n im1_shape = masks.shape\n if im1_shape[:2] == im0_shape[:2]:\n return masks\n if ratio_pad is None: # calculate from im0_shape\n gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new\n pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding\n else:\n pad = ratio_pad[1]\n\n top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)))\n bottom, right = (\n im1_shape[0] - int(round(pad[1] + 0.1)),\n im1_shape[1] - int(round(pad[0] + 0.1)),\n )\n\n if len(masks.shape) < 2:\n raise ValueError(f'\"len of masks shape\" should be 2 or 3, but got {len(masks.shape)}')\n masks = masks[top:bottom, left:right]\n masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))\n if len(masks.shape) == 2:\n masks = masks[:, :, None]\n\n return masks", "chunk_type": "function", "name": "scale_image", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", 
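A quick numpy illustration of the clipping performed by `clip_boxes`/`clip_coords` above, grouping the x and y columns exactly as the ndarray branch does (the boxes are invented):

```python
import numpy as np

shape = (480, 640)  # (height, width)
boxes = np.array([[-5.0, 10.0, 700.0, 200.0],
                  [20.0, -3.0, 100.0, 500.0]])

# x1, x2 clipped to the image width; y1, y2 clipped to the height (grouped ndarray path).
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])
print(boxes)
# [[  0.  10. 640. 200.]
#  [ 20.   0. 100. 480.]]
```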
"start_line": 383, "end_line": 421, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Rescale masks to original image size.\n\nTakes resized and padded masks and rescales them back to the original image dimensions, removing any padding\nthat was applied during preprocessing.\n\nArgs:\n masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].\n im0_shape (tuple): Original image shape as (height, width).\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n\nReturns:\n (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.", "parameters": [ "masks", "im0_shape", "ratio_pad" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_scale_image_d3e50193" }, { "content": "def xyxy2xywh(x):\n \"\"\"\n Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the\n top-left corner and (x2, y2) is the bottom-right corner.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.\n \"\"\"\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center\n y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center\n y[..., 2] = x[..., 2] - x[..., 0] # width\n y[..., 3] = x[..., 3] - x[..., 1] # height\n return y", "chunk_type": "function", "name": "xyxy2xywh", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 424, "end_line": 441, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the\ntop-left corner and (x2, y2) is the bottom-right corner.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xyxy2xywh_011f3ef6" }, { "content": "def xywh2xyxy(x):\n \"\"\"\n Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the\n top-left corner and (x2, y2) is the bottom-right corner. 
Note: ops per 2 channels faster than per channel.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.\n \"\"\"\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n xy = x[..., :2] # centers\n wh = x[..., 2:] / 2 # half width-height\n y[..., :2] = xy - wh # top left xy\n y[..., 2:] = xy + wh # bottom right xy\n return y", "chunk_type": "function", "name": "xywh2xyxy", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 444, "end_line": 461, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the\ntop-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xywh2xyxy_e2f5b2e0" }, { "content": "def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):\n \"\"\"\n Convert normalized bounding box coordinates to pixel coordinates.\n\n Args:\n x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n padw (int): Padding width in pixels.\n padh (int): Padding height in pixels.\n\n Returns:\n y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where\n x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.\n \"\"\"\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x\n y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y\n y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x\n y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y\n return y", "chunk_type": "function", "name": "xywhn2xyxy", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 464, "end_line": 485, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert normalized bounding box coordinates to pixel coordinates.\n\nArgs:\n x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n padw (int): Padding width in pixels.\n padh (int): Padding height in pixels.\n\nReturns:\n y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where\n x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.", "parameters": [ "x", "w: int", "h: int", "padw: int", "padh: int" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", 
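`xyxy2xywh` and `xywh2xyxy` above are inverse transforms; a standalone numpy round-trip check of the same arithmetic, with `np.empty_like` standing in for the package's `empty_like` helper:

```python
import numpy as np

def xyxy_to_xywh(x):
    y = np.empty_like(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2   # center x
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2   # center y
    y[..., 2] = x[..., 2] - x[..., 0]         # width
    y[..., 3] = x[..., 3] - x[..., 1]         # height
    return y

def xywh_to_xyxy(x):
    y = np.empty_like(x)
    xy, wh = x[..., :2], x[..., 2:] / 2
    y[..., :2] = xy - wh                      # top-left corner
    y[..., 2:] = xy + wh                      # bottom-right corner
    return y

boxes = np.array([[10.0, 20.0, 110.0, 220.0]])
assert np.allclose(xywh_to_xyxy(xyxy_to_xywh(boxes)), boxes)
print(xyxy_to_xywh(boxes))  # [[ 60. 120. 100. 200.]]
```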
"re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xywhn2xyxy_00ae5793" }, { "content": "def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):\n \"\"\"\n Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,\n width and height are normalized to image dimensions.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n clip (bool): Whether to clip boxes to image boundaries.\n eps (float): Minimum value for box width and height.\n\n Returns:\n (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.\n \"\"\"\n if clip:\n x = clip_boxes(x, (h - eps, w - eps))\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center\n y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center\n y[..., 2] = (x[..., 2] - x[..., 0]) / w # width\n y[..., 3] = (x[..., 3] - x[..., 1]) / h # height\n return y", "chunk_type": "function", "name": "xyxy2xywhn", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 488, "end_line": 511, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,\nwidth and height are normalized to image dimensions.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n clip (bool): Whether to clip boxes to image boundaries.\n eps (float): Minimum value for box width and height.\n\nReturns:\n (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.", "parameters": [ "x", "w: int", "h: int", "clip: bool", "eps: float" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xyxy2xywhn_bb3f38d9" }, { "content": "def xywh2ltwh(x):\n \"\"\"\n Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x\n y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y\n return y", "chunk_type": "function", "name": "xywh2ltwh", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 514, "end_line": 527, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box 
coordinates in xyltwh format.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xywh2ltwh_dd88b01f" }, { "content": "def xyxy2ltwh(x):\n \"\"\"\n Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 2] = x[..., 2] - x[..., 0] # width\n y[..., 3] = x[..., 3] - x[..., 1] # height\n return y", "chunk_type": "function", "name": "xyxy2ltwh", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 530, "end_line": 543, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xyxy2ltwh_d71804d8" }, { "content": "def ltwh2xywh(x):\n \"\"\"\n Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.\n\n Args:\n x (torch.Tensor): Input bounding box coordinates.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x\n y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y\n return y", "chunk_type": "function", "name": "ltwh2xywh", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 546, "end_line": 559, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.\n\nArgs:\n x (torch.Tensor): Input bounding box coordinates.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_ltwh2xywh_bd19410c" }, { "content": "def xyxyxyxy2xywhr(x):\n \"\"\"\n Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.\n\n Args:\n x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.\n\n Returns:\n (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).\n Rotation values are in radians from 0 to pi/2.\n \"\"\"\n is_torch = isinstance(x, torch.Tensor)\n points = x.cpu().numpy() if is_torch else x\n points = 
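`xywh2ltwh` and `ltwh2xywh` above simply shift between a center-based and a top-left-based box origin. A minimal numpy sketch of that shift and its inverse, on an invented example box:

```python
import numpy as np

box_xywh = np.array([60.0, 120.0, 100.0, 200.0])   # cx, cy, w, h

# Center -> top-left (COCO-style ltwh), as xywh2ltwh() does.
ltwh = box_xywh.copy()
ltwh[0] -= ltwh[2] / 2
ltwh[1] -= ltwh[3] / 2
print(ltwh)        # [ 10.  20. 100. 200.]

# And back to a centered box, as ltwh2xywh() does.
back = ltwh.copy()
back[0] += back[2] / 2
back[1] += back[3] / 2
assert np.allclose(back, box_xywh)
```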
points.reshape(len(x), -1, 2)\n rboxes = []\n for pts in points:\n # NOTE: Use cv2.minAreaRect to get accurate xywhr,\n # especially some objects are cut off by augmentations in dataloader.\n (cx, cy), (w, h), angle = cv2.minAreaRect(pts)\n rboxes.append([cx, cy, w, h, angle / 180 * np.pi])\n return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)", "chunk_type": "function", "name": "xyxyxyxy2xywhr", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 562, "end_line": 582, "start_col": 0, "end_col": 99, "parent_name": null, "docstring": "Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.\n\nReturns:\n (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).\n Rotation values are in radians from 0 to pi/2.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xyxyxyxy2xywhr_a60266e7" }, { "content": "def xywhr2xyxyxyxy(x):\n \"\"\"\n Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.\n\n Args:\n x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).\n Rotation values should be in radians from 0 to pi/2.\n\n Returns:\n (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).\n \"\"\"\n cos, sin, cat, stack = (\n (torch.cos, torch.sin, torch.cat, torch.stack)\n if isinstance(x, torch.Tensor)\n else (np.cos, np.sin, np.concatenate, np.stack)\n )\n\n ctr = x[..., :2]\n w, h, angle = (x[..., i : i + 1] for i in range(2, 5))\n cos_value, sin_value = cos(angle), sin(angle)\n vec1 = [w / 2 * cos_value, w / 2 * sin_value]\n vec2 = [-h / 2 * sin_value, h / 2 * cos_value]\n vec1 = cat(vec1, -1)\n vec2 = cat(vec2, -1)\n pt1 = ctr + vec1 + vec2\n pt2 = ctr + vec1 - vec2\n pt3 = ctr - vec1 - vec2\n pt4 = ctr - vec1 + vec2\n return stack([pt1, pt2, pt3, pt4], -2)", "chunk_type": "function", "name": "xywhr2xyxyxyxy", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 585, "end_line": 613, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.\n\nArgs:\n x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).\n Rotation values should be in radians from 0 to pi/2.\n\nReturns:\n (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_xywhr2xyxyxyxy_eb8258ba" }, { "content": "def ltwh2xyxy(x):\n \"\"\"\n Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.\n\n Args:\n x (np.ndarray | torch.Tensor): Input 
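The two OBB helpers above are inverses of each other up to corner ordering; a minimal round-trip sketch, assuming the `ultralytics` package and torch are installed (shapes follow from the docstrings above):

```python
import torch

from ultralytics.utils.ops import xywhr2xyxyxyxy, xyxyxyxy2xywhr

# One box given as four corner points (x1, y1, ..., x4, y4), shape (N, 8)
corners = torch.tensor([[10.0, 10.0, 50.0, 10.0, 50.0, 30.0, 10.0, 30.0]])

rboxes = xyxyxyxy2xywhr(corners)   # (N, 5): [cx, cy, w, h, rotation], rotation in radians
points = xywhr2xyxyxyxy(rboxes)    # (N, 4, 2) corner points
print(rboxes.shape, points.shape)  # torch.Size([1, 5]) torch.Size([1, 4, 2])
```

Because xyxyxyxy2xywhr goes through cv2.minAreaRect, the recovered corners may come back in a different order than the input points.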
bounding box coordinates.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 2] = x[..., 2] + x[..., 0] # width\n y[..., 3] = x[..., 3] + x[..., 1] # height\n return y", "chunk_type": "function", "name": "ltwh2xyxy", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 616, "end_line": 629, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_ltwh2xyxy_31ad8c0b" }, { "content": "def segments2boxes(segments):\n \"\"\"\n Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).\n\n Args:\n segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.\n\n Returns:\n (np.ndarray): Bounding box coordinates in xywh format.\n \"\"\"\n boxes = []\n for s in segments:\n x, y = s.T # segment xy\n boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy\n return xyxy2xywh(np.array(boxes)) # cls, xywh", "chunk_type": "function", "name": "segments2boxes", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 632, "end_line": 646, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) 
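A short usage sketch for the axis-aligned converters and segments2boxes above (assumes `ultralytics` and NumPy are available; the numbers in the comments follow directly from the formulas in the code):

```python
import numpy as np

from ultralytics.utils.ops import ltwh2xyxy, segments2boxes, xyxy2ltwh

box_xyxy = np.array([[10.0, 20.0, 50.0, 80.0]])    # x1, y1, x2, y2
box_ltwh = xyxy2ltwh(box_xyxy)                     # [[10, 20, 40, 60]] -> x1, y1, w, h
assert np.allclose(ltwh2xyxy(box_ltwh), box_xyxy)  # round-trips back to xyxy

# segments2boxes: one (N, 2) polygon per segment -> (num_segments, 4) xywh boxes
segment = np.array([[12.0, 22.0], [48.0, 25.0], [30.0, 78.0]])
print(segments2boxes([segment]))                   # [[30. 50. 36. 56.]]
```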
to (cls, xywh).\n\nArgs:\n segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.\n\nReturns:\n (np.ndarray): Bounding box coordinates in xywh format.", "parameters": [ "segments" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_segments2boxes_3e6c5756" }, { "content": "def resample_segments(segments, n: int = 1000):\n \"\"\"\n Resample segments to n points each using linear interpolation.\n\n Args:\n segments (list): List of (N, 2) arrays where N is the number of points in each segment.\n n (int): Number of points to resample each segment to.\n\n Returns:\n (list): Resampled segments with n points each.\n \"\"\"\n for i, s in enumerate(segments):\n if len(s) == n:\n continue\n s = np.concatenate((s, s[0:1, :]), axis=0)\n x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)\n xp = np.arange(len(s))\n x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x\n segments[i] = (\n np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T\n ) # segment xy\n return segments", "chunk_type": "function", "name": "resample_segments", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 649, "end_line": 670, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Resample segments to n points each using linear interpolation.\n\nArgs:\n segments (list): List of (N, 2) arrays where N is the number of points in each segment.\n n (int): Number of points to resample each segment to.\n\nReturns:\n (list): Resampled segments with n points each.", "parameters": [ "segments", "n: int" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_resample_segments_bf05b57c" }, { "content": "def crop_mask(masks, boxes):\n \"\"\"\n Crop masks to bounding box regions.\n\n Args:\n masks (torch.Tensor): Masks with shape (N, H, W).\n boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.\n\n Returns:\n (torch.Tensor): Cropped masks.\n \"\"\"\n _, h, w = masks.shape\n x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)\n r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)\n c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)\n\n return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))", "chunk_type": "function", "name": "crop_mask", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 673, "end_line": 689, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": "Crop masks to bounding box regions.\n\nArgs:\n masks (torch.Tensor): Masks with shape (N, H, W).\n boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.\n\nReturns:\n (torch.Tensor): Cropped masks.", "parameters": [ "masks", "boxes" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", 
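resample_segments and crop_mask are easiest to see on toy inputs; a sketch assuming `ultralytics`, NumPy and torch are installed:

```python
import numpy as np
import torch

from ultralytics.utils.ops import crop_mask, resample_segments

# Interpolate a 3-point polygon up to a fixed 100 points
segments = [np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]], dtype=np.float32)]
print(resample_segments(segments, n=100)[0].shape)  # (100, 2)

# Zero out mask pixels that fall outside each box
masks = torch.ones(2, 8, 8)                   # (N, H, W)
boxes = torch.tensor([[1.0, 1.0, 4.0, 4.0],   # xyxy per mask, in mask pixels
                      [2.0, 2.0, 7.0, 7.0]])
cropped = crop_mask(masks, boxes)
print(cropped[0].sum().item())                # 9.0 -> only the 3x3 region inside the first box survives
```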
"typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_crop_mask_25034c2c" }, { "content": "def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):\n \"\"\"\n Apply masks to bounding boxes using mask head output.\n\n Args:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n upsample (bool): Whether to upsample masks to original image size.\n\n Returns:\n (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w\n are the height and width of the input image. The mask is applied to the bounding boxes.\n \"\"\"\n c, mh, mw = protos.shape # CHW\n ih, iw = shape\n masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW\n width_ratio = mw / iw\n height_ratio = mh / ih\n\n downsampled_bboxes = bboxes.clone()\n downsampled_bboxes[:, 0] *= width_ratio\n downsampled_bboxes[:, 2] *= width_ratio\n downsampled_bboxes[:, 3] *= height_ratio\n downsampled_bboxes[:, 1] *= height_ratio\n\n masks = crop_mask(masks, downsampled_bboxes) # CHW\n if upsample:\n masks = F.interpolate(masks[None], shape, mode=\"bilinear\", align_corners=False)[0] # CHW\n return masks.gt_(0.0)", "chunk_type": "function", "name": "process_mask", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 692, "end_line": 722, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Apply masks to bounding boxes using mask head output.\n\nArgs:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n upsample (bool): Whether to upsample masks to original image size.\n\nReturns:\n (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w\n are the height and width of the input image. 
The mask is applied to the bounding boxes.", "parameters": [ "protos", "masks_in", "bboxes", "shape", "upsample: bool" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_process_mask_7c305c15" }, { "content": "def process_mask_native(protos, masks_in, bboxes, shape):\n \"\"\"\n Apply masks to bounding boxes using mask head output with native upsampling.\n\n Args:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n\n Returns:\n (torch.Tensor): Binary mask tensor with shape (H, W, N).\n \"\"\"\n c, mh, mw = protos.shape # CHW\n masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)\n masks = scale_masks(masks[None], shape)[0] # CHW\n masks = crop_mask(masks, bboxes) # CHW\n return masks.gt_(0.0)", "chunk_type": "function", "name": "process_mask_native", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 725, "end_line": 742, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Apply masks to bounding boxes using mask head output with native upsampling.\n\nArgs:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n\nReturns:\n (torch.Tensor): Binary mask tensor with shape (H, W, N).", "parameters": [ "protos", "masks_in", "bboxes", "shape" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_process_mask_native_556480d4" }, { "content": "def scale_masks(masks, shape, padding: bool = True):\n \"\"\"\n Rescale segment masks to target shape.\n\n Args:\n masks (torch.Tensor): Masks with shape (N, C, H, W).\n shape (tuple): Target height and width as (height, width).\n padding (bool): Whether masks are based on YOLO-style augmented images with padding.\n\n Returns:\n (torch.Tensor): Rescaled masks.\n \"\"\"\n mh, mw = masks.shape[2:]\n gain = min(mh / shape[0], mw / shape[1]) # gain = old / new\n pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding\n if padding:\n pad[0] /= 2\n pad[1] /= 2\n top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) if padding else (0, 0) # y, x\n bottom, right = (\n mh - int(round(pad[1] + 0.1)),\n mw - int(round(pad[0] + 0.1)),\n )\n masks = masks[..., top:bottom, left:right]\n\n masks = F.interpolate(masks, shape, mode=\"bilinear\", align_corners=False) # NCHW\n return masks", "chunk_type": "function", "name": "scale_masks", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 745, "end_line": 771, "start_col": 0, "end_col": 16, "parent_name": 
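The two mask decoders above differ mainly in where the upsampling happens (on the prototype grid vs. via scale_masks); a shape-only walkthrough with random tensors, assuming `ultralytics` and torch (these are placeholders with the documented shapes, not real model outputs):

```python
import torch

from ultralytics.utils.ops import process_mask, process_mask_native

protos = torch.randn(32, 160, 160)   # (mask_dim, mask_h, mask_w) from the segmentation head
masks_in = torch.randn(5, 32)        # 5 post-NMS detections x 32 mask coefficients
bboxes = torch.tensor([[100.0, 80.0, 300.0, 240.0]]).repeat(5, 1)  # xyxy in input-image pixels
shape = (640, 640)                   # network input (height, width)

m1 = process_mask(protos, masks_in, bboxes, shape, upsample=True)
m2 = process_mask_native(protos, masks_in, bboxes, shape)
print(m1.shape, m2.shape)            # torch.Size([5, 640, 640]) torch.Size([5, 640, 640])
```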
null, "docstring": "Rescale segment masks to target shape.\n\nArgs:\n masks (torch.Tensor): Masks with shape (N, C, H, W).\n shape (tuple): Target height and width as (height, width).\n padding (bool): Whether masks are based on YOLO-style augmented images with padding.\n\nReturns:\n (torch.Tensor): Rescaled masks.", "parameters": [ "masks", "shape", "padding: bool" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_scale_masks_29488e17" }, { "content": "def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):\n \"\"\"\n Rescale segment coordinates from img1_shape to img0_shape.\n\n Args:\n img1_shape (tuple): Shape of the source image.\n coords (torch.Tensor): Coordinates to scale with shape (N, 2).\n img0_shape (tuple): Shape of the target image.\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n normalize (bool): Whether to normalize coordinates to range [0, 1].\n padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.\n\n Returns:\n (torch.Tensor): Scaled coordinates.\n \"\"\"\n if ratio_pad is None: # calculate from img0_shape\n gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new\n pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding\n else:\n gain = ratio_pad[0][0]\n pad = ratio_pad[1]\n\n if padding:\n coords[..., 0] -= pad[0] # x padding\n coords[..., 1] -= pad[1] # y padding\n coords[..., 0] /= gain\n coords[..., 1] /= gain\n coords = clip_coords(coords, img0_shape)\n if normalize:\n coords[..., 0] /= img0_shape[1] # width\n coords[..., 1] /= img0_shape[0] # height\n return coords", "chunk_type": "function", "name": "scale_coords", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 774, "end_line": 805, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Rescale segment coordinates from img1_shape to img0_shape.\n\nArgs:\n img1_shape (tuple): Shape of the source image.\n coords (torch.Tensor): Coordinates to scale with shape (N, 2).\n img0_shape (tuple): Shape of the target image.\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n normalize (bool): Whether to normalize coordinates to range [0, 1].\n padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.\n\nReturns:\n (torch.Tensor): Scaled coordinates.", "parameters": [ "img1_shape", "coords", "img0_shape", "ratio_pad", "normalize: bool", "padding: bool" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_scale_coords_af97eefd" }, { "content": "def regularize_rboxes(rboxes):\n \"\"\"\n Regularize rotated bounding boxes to range [0, pi/2].\n\n Args:\n rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.\n\n Returns:\n (torch.Tensor): Regularized rotated boxes.\n \"\"\"\n x, y, w, h, t = 
rboxes.unbind(dim=-1)\n # Swap edge if t >= pi/2 while not being symmetrically opposite\n swap = t % math.pi >= math.pi / 2\n w_ = torch.where(swap, h, w)\n h_ = torch.where(swap, w, h)\n t = t % (math.pi / 2)\n return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes", "chunk_type": "function", "name": "regularize_rboxes", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 808, "end_line": 824, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": "Regularize rotated bounding boxes to range [0, pi/2].\n\nArgs:\n rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.\n\nReturns:\n (torch.Tensor): Regularized rotated boxes.", "parameters": [ "rboxes" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_regularize_rboxes_c10e57ca" }, { "content": "def masks2segments(masks, strategy: str = \"all\"):\n \"\"\"\n Convert masks to segments using contour detection.\n\n Args:\n masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).\n strategy (str): Segmentation strategy, either 'all' or 'largest'.\n\n Returns:\n (list): List of segment masks as float32 arrays.\n \"\"\"\n from ultralytics.data.converter import merge_multi_segment\n\n segments = []\n for x in masks.int().cpu().numpy().astype(\"uint8\"):\n c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]\n if c:\n if strategy == \"all\": # merge and concatenate all segments\n c = (\n np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))\n if len(c) > 1\n else c[0].reshape(-1, 2)\n )\n elif strategy == \"largest\": # select largest segment\n c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)\n else:\n c = np.zeros((0, 2)) # no segments found\n segments.append(c.astype(\"float32\"))\n return segments", "chunk_type": "function", "name": "masks2segments", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 827, "end_line": 855, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Convert masks to segments using contour detection.\n\nArgs:\n masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).\n strategy (str): Segmentation strategy, either 'all' or 'largest'.\n\nReturns:\n (list): List of segment masks as float32 arrays.", "parameters": [ "masks", "strategy: str" ], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_masks2segments_1e64c283" }, { "content": "def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:\n \"\"\"\n Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.\n\n Args:\n batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.\n\n Returns:\n (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.\n \"\"\"\n return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()", "chunk_type": "function", "name": 
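masks2segments and convert_torch2numpy_batch sit at the end of the post-processing chain; a small sketch with synthetic inputs, assuming `ultralytics` and torch are installed:

```python
import torch

from ultralytics.utils.ops import convert_torch2numpy_batch, masks2segments

masks = torch.zeros(1, 160, 160)
masks[0, 40:80, 40:80] = 1                     # one square blob
segments = masks2segments(masks, strategy="largest")
print(segments[0].shape, segments[0].dtype)    # (K, 2) float32 polygon points

batch = torch.rand(2, 3, 64, 64)               # float32 BCHW in [0, 1]
print(convert_torch2numpy_batch(batch).shape)  # (2, 64, 64, 3), dtype uint8
```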
"convert_torch2numpy_batch", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 858, "end_line": 868, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": "Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.\n\nArgs:\n batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.\n\nReturns:\n (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.", "parameters": [ "batch: torch.Tensor" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_convert_torch2numpy_batch_47ba17e5" }, { "content": "def clean_str(s):\n \"\"\"\n Clean a string by replacing special characters with '_' character.\n\n Args:\n s (str): A string needing special characters replaced.\n\n Returns:\n (str): A string with special characters replaced by an underscore _.\n \"\"\"\n return re.sub(pattern=\"[|@#!¡·$€%&()=?¿^*;:,¨´><+]\", repl=\"_\", string=s)", "chunk_type": "function", "name": "clean_str", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 871, "end_line": 881, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": "Clean a string by replacing special characters with '_' character.\n\nArgs:\n s (str): A string needing special characters replaced.\n\nReturns:\n (str): A string with special characters replaced by an underscore _.", "parameters": [ "s" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_clean_str_119c2be3" }, { "content": "def empty_like(x):\n \"\"\"Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.\"\"\"\n return (\n torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)\n )", "chunk_type": "function", "name": "empty_like", "file_path": "ultralytics\\ultralytics\\utils\\ops.py", "start_line": 884, "end_line": 888, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.", "parameters": [ "x" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "math", "re", "time", "typing.Optional", "cv2", "numpy", "torch", "torch.nn.functional", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.batch_probiou", "torchvision", "ultralytics.data.converter.merge_multi_segment" ], "chunk_id": "function_empty_like_35cd77b9" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_9e1c91d6" }, { "content": "from contextlib import contextmanager", "chunk_type": "import", "name": "contextmanager", "file_path": 
"ultralytics\\ultralytics\\utils\\patches.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextmanager_860a4511" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_5b142c7f" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_e82f0c22" }, { "content": "from typing import Any, Dict, List, Optional", "chunk_type": "import", "name": "Any, Dict, List, Optional", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional_3383557e" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_c903a9b5" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_ff8c444e" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_57df7b1e" }, { "content": "_imshow = cv2.imshow # copy to avoid recursion errors", "chunk_type": "variable", "name": "_imshow", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable__imshow_015f4d04" }, { "content": "def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]:\n \"\"\"\n Read an image from a file with multilanguage filename support.\n\n Args:\n filename (str): Path to the file to read.\n flags (int, optional): Flag that can take values of cv2.IMREAD_*. 
Controls how the image is read.\n\n Returns:\n (np.ndarray | None): The read image array, or None if reading fails.\n\n Examples:\n >>> img = imread(\"path/to/image.jpg\")\n >>> img = imread(\"path/to/image.jpg\", cv2.IMREAD_GRAYSCALE)\n \"\"\"\n file_bytes = np.fromfile(filename, np.uint8)\n if filename.endswith((\".tiff\", \".tif\")):\n success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)\n if success:\n # Handle RGB images in tif/tiff format\n return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)\n return None\n else:\n im = cv2.imdecode(file_bytes, flags)\n return im[..., None] if im.ndim == 2 else im # Always ensure 3 dimensions", "chunk_type": "function", "name": "imread", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 18, "end_line": 42, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": "Read an image from a file with multilanguage filename support.\n\nArgs:\n filename (str): Path to the file to read.\n flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.\n\nReturns:\n (np.ndarray | None): The read image array, or None if reading fails.\n\nExamples:\n >>> img = imread(\"path/to/image.jpg\")\n >>> img = imread(\"path/to/image.jpg\", cv2.IMREAD_GRAYSCALE)", "parameters": [ "filename: str", "flags: int" ], "return_type": "Optional[np.ndarray]", "decorators": [], "complexity_score": 3, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_imread_53ca1577" }, { "content": "def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:\n \"\"\"\n Write an image to a file with multilanguage filename support.\n\n Args:\n filename (str): Path to the file to write.\n img (np.ndarray): Image to write.\n params (List[int], optional): Additional parameters for image encoding.\n\n Returns:\n (bool): True if the file was written successfully, False otherwise.\n\n Examples:\n >>> import numpy as np\n >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image\n >>> success = imwrite(\"output.jpg\", img) # Write image to file\n >>> print(success)\n True\n \"\"\"\n try:\n cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)\n return True\n except Exception:\n return False", "chunk_type": "function", "name": "imwrite", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 45, "end_line": 68, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Write an image to a file with multilanguage filename support.\n\nArgs:\n filename (str): Path to the file to write.\n img (np.ndarray): Image to write.\n params (List[int], optional): Additional parameters for image encoding.\n\nReturns:\n (bool): True if the file was written successfully, False otherwise.\n\nExamples:\n >>> import numpy as np\n >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image\n >>> success = imwrite(\"output.jpg\", img) # Write image to file\n >>> print(success)\n True", "parameters": [ "filename: str", "img: np.ndarray", "params: Optional[List[int]]" ], "return_type": "bool", "decorators": [], "complexity_score": 2, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", 
"ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_imwrite_bfb76116" }, { "content": "def imshow(winname: str, mat: np.ndarray) -> None:\n \"\"\"\n Display an image in the specified window with multilanguage window name support.\n\n This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles\n multilanguage window names by encoding them properly for OpenCV compatibility.\n\n Args:\n winname (str): Name of the window where the image will be displayed. If a window with this name already\n exists, the image will be displayed in that window.\n mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.\n\n Examples:\n >>> import numpy as np\n >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image\n >>> img[:100, :100] = [255, 0, 0] # Add a blue square\n >>> imshow(\"Example Window\", img) # Display the image\n \"\"\"\n _imshow(winname.encode(\"unicode_escape\").decode(), mat)", "chunk_type": "function", "name": "imshow", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 71, "end_line": 89, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "Display an image in the specified window with multilanguage window name support.\n\nThis function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles\nmultilanguage window names by encoding them properly for OpenCV compatibility.\n\nArgs:\n winname (str): Name of the window where the image will be displayed. If a window with this name already\n exists, the image will be displayed in that window.\n mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.\n\nExamples:\n >>> import numpy as np\n >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image\n >>> img[:100, :100] = [255, 0, 0] # Add a blue square\n >>> imshow(\"Example Window\", img) # Display the image", "parameters": [ "winname: str", "mat: np.ndarray" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_imshow_a3da89bb" }, { "content": "_torch_save = torch.save", "chunk_type": "variable", "name": "_torch_save", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 93, "end_line": 93, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable__torch_save_18ff7425" }, { "content": "def torch_load(*args, **kwargs):\n \"\"\"\n Load a PyTorch model with updated arguments to avoid warnings.\n\n This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.\n\n Args:\n *args (Any): Variable length argument list to pass to torch.load.\n **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.\n\n Returns:\n (Any): The loaded PyTorch object.\n\n Notes:\n For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'\n if the argument is not provided, to avoid deprecation warnings.\n \"\"\"\n from ultralytics.utils.torch_utils import TORCH_1_13\n\n if TORCH_1_13 and \"weights_only\" not in kwargs:\n kwargs[\"weights_only\"] = False\n\n return torch.load(*args, **kwargs)", 
"chunk_type": "function", "name": "torch_load", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 96, "end_line": 118, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Load a PyTorch model with updated arguments to avoid warnings.\n\nThis function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.\n\nArgs:\n *args (Any): Variable length argument list to pass to torch.load.\n **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.\n\nReturns:\n (Any): The loaded PyTorch object.\n\nNotes:\n For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'\n if the argument is not provided, to avoid deprecation warnings.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_torch_load_8dfa9a14" }, { "content": "def torch_save(*args, **kwargs):\n \"\"\"\n Save PyTorch objects with retry mechanism for robustness.\n\n This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur\n due to device flushing delays or antivirus scanning.\n\n Args:\n *args (Any): Positional arguments to pass to torch.save.\n **kwargs (Any): Keyword arguments to pass to torch.save.\n\n Examples:\n >>> model = torch.nn.Linear(10, 1)\n >>> torch_save(model.state_dict(), \"model.pt\")\n \"\"\"\n for i in range(4): # 3 retries\n try:\n return _torch_save(*args, **kwargs)\n except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan\n if i == 3:\n raise e\n time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s", "chunk_type": "function", "name": "torch_save", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 121, "end_line": 142, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Save PyTorch objects with retry mechanism for robustness.\n\nThis function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur\ndue to device flushing delays or antivirus scanning.\n\nArgs:\n *args (Any): Positional arguments to pass to torch.save.\n **kwargs (Any): Keyword arguments to pass to torch.save.\n\nExamples:\n >>> model = torch.nn.Linear(10, 1)\n >>> torch_save(model.state_dict(), \"model.pt\")", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_torch_save_48c194ec" }, { "content": "def arange_patch(args):\n \"\"\"\n Workaround for ONNX torch.arange incompatibility with FP16.\n\n https://github.com/pytorch/pytorch/issues/148041.\n \"\"\"\n if args.dynamic and args.half and args.format == \"onnx\":\n func = torch.arange\n\n def arange(*args, dtype=None, **kwargs):\n \"\"\"Return a 1-D tensor of size with values from the interval and common difference.\"\"\"\n return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype\n\n torch.arange = arange # patch\n yield\n torch.arange = func # unpatch\n else:\n yield", "chunk_type": "function", "name": "arange_patch", "file_path": 
"ultralytics\\ultralytics\\utils\\patches.py", "start_line": 146, "end_line": 163, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": "Workaround for ONNX torch.arange incompatibility with FP16.\n\nhttps://github.com/pytorch/pytorch/issues/148041.", "parameters": [ "args" ], "return_type": null, "decorators": [ "contextmanager" ], "complexity_score": 2, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_arange_patch_143727d9" }, { "content": "def override_configs(args, overrides: Optional[Dict[str, Any]] = None):\n \"\"\"\n Context manager to temporarily override configurations in args.\n\n Args:\n args (IterableSimpleNamespace): Original configuration arguments.\n overrides (Dict[str, Any]): Dictionary of overrides to apply.\n\n Yields:\n (IterableSimpleNamespace): Configuration arguments with overrides applied.\n \"\"\"\n if overrides:\n original_args = copy(args)\n for key, value in overrides.items():\n setattr(args, key, value)\n try:\n yield args\n finally:\n args.__dict__.update(original_args.__dict__)\n else:\n yield args", "chunk_type": "function", "name": "override_configs", "file_path": "ultralytics\\ultralytics\\utils\\patches.py", "start_line": 167, "end_line": 187, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Context manager to temporarily override configurations in args.\n\nArgs:\n args (IterableSimpleNamespace): Original configuration arguments.\n overrides (Dict[str, Any]): Dictionary of overrides to apply.\n\nYields:\n (IterableSimpleNamespace): Configuration arguments with overrides applied.", "parameters": [ "args", "overrides: Optional[Dict[str, Any]]" ], "return_type": null, "decorators": [ "contextmanager" ], "complexity_score": 3, "dependencies": [ "time", "contextlib.contextmanager", "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "cv2", "numpy", "torch", "ultralytics.utils.torch_utils.TORCH_1_13" ], "chunk_id": "function_override_configs_29af933f" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_4cce6133" }, { "content": "import warnings", "chunk_type": "import", "name": "warnings", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_warnings_8fac1370" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_fd97ea76" }, { "content": "from typing import Any, Callable, Dict, List, Optional, Union", "chunk_type": "import", "name": "Any, Callable, Dict, List, Optional, Union", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", 
"start_line": 6, "end_line": 6, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Callable, Dict, List, Optional, Union_297079a9" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_e3ab8784" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_3015f2ee" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_0baed0fc" }, { "content": "from PIL import Image, ImageDraw, ImageFont", "chunk_type": "import", "name": "Image, ImageDraw, ImageFont", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image, ImageDraw, ImageFont_d1d0340f" }, { "content": "from PIL import __version__ as pil_version", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___08b5082a" }, { "content": "from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded", "chunk_type": "import", "name": "IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 97, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded_bccb5e0f" }, { "content": "from ultralytics.utils.checks import check_font, check_version, is_ascii", "chunk_type": "import", "name": "check_font, check_version, is_ascii", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_font, check_version, is_ascii_78665ee3" }, { "content": "from ultralytics.utils.files import increment_path", "chunk_type": "import", "name": "increment_path", "file_path": 
"ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_increment_path_6587a485" }, { "content": "class Colors:\n \"\"\"\n Ultralytics color palette for visualization and plotting.\n\n This class provides methods to work with the Ultralytics color palette, including converting hex color codes to\n RGB values and accessing predefined color schemes for object detection and pose estimation.\n\n Attributes:\n palette (List[tuple]): List of RGB color tuples for general use.\n n (int): The number of colors in the palette.\n pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.\n\n Examples:\n >>> from ultralytics.utils.plotting import Colors\n >>> colors = Colors()\n >>> colors(5, True) # Returns BGR format: (221, 111, 255)\n >>> colors(5, False) # Returns RGB format: (255, 111, 221)\n\n ## Ultralytics Color Palette\n\n | Index | Color | HEX | RGB |\n |-------|-------------------------------------------------------------------|-----------|-------------------|\n | 0 | | `#042aff` | (4, 42, 255) |\n | 1 | | `#0bdbeb` | (11, 219, 235) |\n | 2 | | `#f3f3f3` | (243, 243, 243) |\n | 3 | | `#00dfb7` | (0, 223, 183) |\n | 4 | | `#111f68` | (17, 31, 104) |\n | 5 | | `#ff6fdd` | (255, 111, 221) |\n | 6 | | `#ff444f` | (255, 68, 79) |\n | 7 | | `#cced00` | (204, 237, 0) |\n | 8 | | `#00f344` | (0, 243, 68) |\n | 9 | | `#bd00ff` | (189, 0, 255) |\n | 10 | | `#00b4ff` | (0, 180, 255) |\n | 11 | | `#dd00ba` | (221, 0, 186) |\n | 12 | | `#00ffff` | (0, 255, 255) |\n | 13 | | `#26c000` | (38, 192, 0) |\n | 14 | | `#01ffb3` | (1, 255, 179) |\n | 15 | | `#7d24ff` | (125, 36, 255) |\n | 16 | | `#7b0068` | (123, 0, 104) |\n | 17 | | `#ff1b6c` | (255, 27, 108) |\n | 18 | | `#fc6d2f` | (252, 109, 47) |\n | 19 | | `#a2ff0b` | (162, 255, 11) |\n\n ## Pose Color Palette\n\n | Index | Color | HEX | RGB |\n |-------|-------------------------------------------------------------------|-----------|-------------------|\n | 0 | | `#ff8000` | (255, 128, 0) |\n | 1 | | `#ff9933` | (255, 153, 51) |\n | 2 | | `#ffb266` | (255, 178, 102) |\n | 3 | | `#e6e600` | (230, 230, 0) |\n | 4 | | `#ff99ff` | (255, 153, 255) |\n | 5 | | `#99ccff` | (153, 204, 255) |\n | 6 | | `#ff66ff` | (255, 102, 255) |\n | 7 | | `#ff33ff` | (255, 51, 255) |\n | 8 | | `#66b2ff` | (102, 178, 255) |\n | 9 | | `#3399ff` | (51, 153, 255) |\n | 10 | | `#ff9999` | (255, 153, 153) |\n | 11 | | `#ff6666` | (255, 102, 102) |\n | 12 | | `#ff3333` | (255, 51, 51) |\n | 13 | | `#99ff99` | (153, 255, 153) |\n | 14 | | `#66ff66` | (102, 255, 102) |\n | 15 | | `#33ff33` | (51, 255, 51) |\n | 16 | | `#00ff00` | (0, 255, 0) |\n | 17 | | `#0000ff` | (0, 0, 255) |\n | 18 | | `#ff0000` | (255, 0, 0) |\n | 19 | | `#ffffff` | (255, 255, 255) |\n\n !!! 
note \"Ultralytics Brand Colors\"\n\n For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).\n Please use the official Ultralytics colors for all marketing materials.\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().\"\"\"\n hexs = (\n \"042AFF\",\n \"0BDBEB\",\n \"F3F3F3\",\n \"00DFB7\",\n \"111F68\",\n \"FF6FDD\",\n \"FF444F\",\n \"CCED00\",\n \"00F344\",\n \"BD00FF\",\n \"00B4FF\",\n \"DD00BA\",\n \"00FFFF\",\n \"26C000\",\n \"01FFB3\",\n \"7D24FF\",\n \"7B0068\",\n \"FF1B6C\",\n \"FC6D2F\",\n \"A2FF0B\",\n )\n self.palette = [self.hex2rgb(f\"#{c}\") for c in hexs]\n self.n = len(self.palette)\n self.pose_palette = np.array(\n [\n [255, 128, 0],\n [255, 153, 51],\n [255, 178, 102],\n [230, 230, 0],\n [255, 153, 255],\n [153, 204, 255],\n [255, 102, 255],\n [255, 51, 255],\n [102, 178, 255],\n [51, 153, 255],\n [255, 153, 153],\n [255, 102, 102],\n [255, 51, 51],\n [153, 255, 153],\n [102, 255, 102],\n [51, 255, 51],\n [0, 255, 0],\n [0, 0, 255],\n [255, 0, 0],\n [255, 255, 255],\n ],\n dtype=np.uint8,\n )\n\n def __call__(self, i: int, bgr: bool = False) -> tuple:\n \"\"\"\n Convert hex color codes to RGB values.\n\n Args:\n i (int): Color index.\n bgr (bool, optional): Whether to return BGR format instead of RGB.\n\n Returns:\n (tuple): RGB or BGR color tuple.\n \"\"\"\n c = self.palette[int(i) % self.n]\n return (c[2], c[1], c[0]) if bgr else c\n\n @staticmethod\n def hex2rgb(h: str) -> tuple:\n \"\"\"Convert hex color codes to RGB values (i.e. default PIL order).\"\"\"\n return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))", "chunk_type": "class", "name": "Colors", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 19, "end_line": 162, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": "Ultralytics color palette for visualization and plotting.\n\nThis class provides methods to work with the Ultralytics color palette, including converting hex color codes to\nRGB values and accessing predefined color schemes for object detection and pose estimation.\n\nAttributes:\n palette (List[tuple]): List of RGB color tuples for general use.\n n (int): The number of colors in the palette.\n pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.\n\nExamples:\n >>> from ultralytics.utils.plotting import Colors\n >>> colors = Colors()\n >>> colors(5, True) # Returns BGR format: (221, 111, 255)\n >>> colors(5, False) # Returns RGB format: (255, 111, 221)\n\n## Ultralytics Color Palette\n\n| Index | Color | HEX | RGB |\n|-------|-------------------------------------------------------------------|-----------|-------------------|\n| 0 | | `#042aff` | (4, 42, 255) |\n| 1 | | `#0bdbeb` | (11, 219, 235) |\n| 2 | | `#f3f3f3` | (243, 243, 243) |\n| 3 | | `#00dfb7` | (0, 223, 183) |\n| 4 | | `#111f68` | (17, 31, 104) |\n| 5 | | `#ff6fdd` | (255, 111, 221) |\n| 6 | | `#ff444f` | (255, 68, 79) |\n| 7 | | `#cced00` | (204, 237, 0) |\n| 8 | | `#00f344` | (0, 243, 68) |\n| 9 | | `#bd00ff` | (189, 0, 255) |\n| 10 | | `#00b4ff` | (0, 180, 255) |\n| 11 | | `#dd00ba` | (221, 0, 186) |\n| 12 | | `#00ffff` | (0, 255, 255) |\n| 13 | | `#26c000` | (38, 192, 0) |\n| 14 | | `#01ffb3` | (1, 255, 179) |\n| 15 | | `#7d24ff` | (125, 36, 255) |\n| 16 | | `#7b0068` | (123, 0, 104) |\n| 17 | | `#ff1b6c` | (255, 27, 108) |\n| 18 | | `#fc6d2f` | (252, 109, 47) |\n| 19 | | `#a2ff0b` | (162, 255, 11) |\n\n## Pose Color Palette\n\n| 
Index | Color | HEX | RGB |\n|-------|-------------------------------------------------------------------|-----------|-------------------|\n| 0 | | `#ff8000` | (255, 128, 0) |\n| 1 | | `#ff9933` | (255, 153, 51) |\n| 2 | | `#ffb266` | (255, 178, 102) |\n| 3 | | `#e6e600` | (230, 230, 0) |\n| 4 | | `#ff99ff` | (255, 153, 255) |\n| 5 | | `#99ccff` | (153, 204, 255) |\n| 6 | | `#ff66ff` | (255, 102, 255) |\n| 7 | | `#ff33ff` | (255, 51, 255) |\n| 8 | | `#66b2ff` | (102, 178, 255) |\n| 9 | | `#3399ff` | (51, 153, 255) |\n| 10 | | `#ff9999` | (255, 153, 153) |\n| 11 | | `#ff6666` | (255, 102, 102) |\n| 12 | | `#ff3333` | (255, 51, 51) |\n| 13 | | `#99ff99` | (153, 255, 153) |\n| 14 | | `#66ff66` | (102, 255, 102) |\n| 15 | | `#33ff33` | (51, 255, 51) |\n| 16 | | `#00ff00` | (0, 255, 0) |\n| 17 | | `#0000ff` | (0, 0, 255) |\n| 18 | | `#ff0000` | (255, 0, 0) |\n| 19 | | `#ffffff` | (255, 255, 255) |\n\n!!! note \"Ultralytics Brand Colors\"\n\n For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).\n Please use the official Ultralytics colors for all marketing materials.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "class_Colors_52a994d8" }, { "content": "colors = Colors() # create instance for 'from utils.plots import colors'", "chunk_type": "variable", "name": "colors", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 165, "end_line": 165, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_colors_0f7f05d2" }, { "content": "class Annotator:\n \"\"\"\n Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.\n\n Attributes:\n im (Image.Image | np.ndarray): The image to annotate.\n pil (bool): Whether to use PIL or cv2 for drawing annotations.\n font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.\n lw (float): Line width for drawing.\n skeleton (List[List[int]]): Skeleton structure for keypoints.\n limb_color (List[int]): Color palette for limbs.\n kpt_color (List[int]): Color palette for keypoints.\n dark_colors (set): Set of colors considered dark for text contrast.\n light_colors (set): Set of colors considered light for text contrast.\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.box_label([10, 10, 100, 100], \"person\", (255, 0, 0))\n \"\"\"\n\n def __init__(\n self,\n im,\n 
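The module-level `colors` instance created above maps any class index to a stable palette colour, wrapping modulo the 20-entry palette; a short sketch (assumes `ultralytics` is installed):

```python
from ultralytics.utils.plotting import colors

print(colors(0))                # (4, 42, 255) in RGB order
print(colors(0, bgr=True))      # (255, 42, 4) for OpenCV-style BGR drawing
print(colors(23) == colors(3))  # True: index 23 wraps to palette entry 3
```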
line_width: Optional[int] = None,\n font_size: Optional[int] = None,\n font: str = \"Arial.ttf\",\n pil: bool = False,\n example: str = \"abc\",\n ):\n \"\"\"Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.\"\"\"\n non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic\n input_is_pil = isinstance(im, Image.Image)\n self.pil = pil or non_ascii or input_is_pil\n self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)\n if not input_is_pil:\n if im.shape[2] == 1: # handle grayscale\n im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)\n elif im.shape[2] > 3: # multispectral\n im = np.ascontiguousarray(im[..., :3])\n if self.pil: # use PIL\n self.im = im if input_is_pil else Image.fromarray(im)\n if self.im.mode not in {\"RGB\", \"RGBA\"}: # multispectral\n self.im = self.im.convert(\"RGB\")\n self.draw = ImageDraw.Draw(self.im, \"RGBA\")\n try:\n font = check_font(\"Arial.Unicode.ttf\" if non_ascii else font)\n size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)\n self.font = ImageFont.truetype(str(font), size)\n except Exception:\n self.font = ImageFont.load_default()\n # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)\n if check_version(pil_version, \"9.2.0\"):\n self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height\n else: # use cv2\n assert im.data.contiguous, \"Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images.\"\n self.im = im if im.flags.writeable else im.copy()\n self.tf = max(self.lw - 1, 1) # font thickness\n self.sf = self.lw / 3 # font scale\n # Pose\n self.skeleton = [\n [16, 14],\n [14, 12],\n [17, 15],\n [15, 13],\n [12, 13],\n [6, 12],\n [7, 13],\n [6, 7],\n [6, 8],\n [7, 9],\n [8, 10],\n [9, 11],\n [2, 3],\n [1, 2],\n [1, 3],\n [2, 4],\n [3, 5],\n [4, 6],\n [5, 7],\n ]\n\n self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]\n self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]\n self.dark_colors = {\n (235, 219, 11),\n (243, 243, 243),\n (183, 223, 0),\n (221, 111, 255),\n (0, 237, 204),\n (68, 243, 0),\n (255, 255, 0),\n (179, 255, 1),\n (11, 255, 162),\n }\n self.light_colors = {\n (255, 42, 4),\n (79, 68, 255),\n (255, 0, 189),\n (255, 180, 0),\n (186, 0, 221),\n (0, 192, 38),\n (255, 36, 125),\n (104, 0, 123),\n (108, 27, 255),\n (47, 109, 252),\n (104, 31, 17),\n }\n\n def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:\n \"\"\"\n Assign text color based on background color.\n\n Args:\n color (tuple, optional): The background color of the rectangle for text (B, G, R).\n txt_color (tuple, optional): The color of the text (R, G, B).\n\n Returns:\n (tuple): Text color for label.\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255)\n \"\"\"\n if color in self.dark_colors:\n return 104, 31, 17\n elif color in self.light_colors:\n return 255, 255, 255\n else:\n return txt_color\n\n def box_label(self, box, label: str = \"\", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):\n \"\"\"\n Draw a bounding box on an image with a given label.\n\n Args:\n box (tuple): The bounding box coordinates (x1, y1, x2, y2).\n label (str, optional): The text 
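As a small illustration of the text-contrast logic above (a sketch on a dummy canvas, not part of the library), get_txt_color only switches ink color for backgrounds listed in the hard-coded dark_colors / light_colors sets; anything else falls through to the caller-supplied txt_color.

import numpy as np
from ultralytics.utils.plotting import Annotator

im = np.zeros((480, 640, 3), dtype=np.uint8)   # dummy canvas
ann = Annotator(im, line_width=2)

print(ann.get_txt_color(color=(255, 42, 4)))                         # light background -> (255, 255, 255)
print(ann.get_txt_color(color=(243, 243, 243)))                      # dark-text background -> (104, 31, 17)
print(ann.get_txt_color(color=(10, 10, 10), txt_color=(0, 255, 0)))  # unlisted color -> txt_color unchanged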
label to be displayed.\n color (tuple, optional): The background color of the rectangle (B, G, R).\n txt_color (tuple, optional): The color of the text (R, G, B).\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.box_label(box=[10, 20, 30, 40], label=\"person\")\n \"\"\"\n txt_color = self.get_txt_color(color, txt_color)\n if isinstance(box, torch.Tensor):\n box = box.tolist()\n\n multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)\n p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))\n if self.pil:\n self.draw.polygon(\n [tuple(b) for b in box], width=self.lw, outline=color\n ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)\n if label:\n w, h = self.font.getsize(label) # text width, height\n outside = p1[1] >= h # label fits outside box\n if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image\n p1 = self.im.size[0] - w, p1[1]\n self.draw.rectangle(\n (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),\n fill=color,\n )\n # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0\n self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)\n else: # cv2\n cv2.polylines(\n self.im, [np.asarray(box, dtype=int)], True, color, self.lw\n ) if multi_points else cv2.rectangle(\n self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA\n )\n if label:\n w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height\n h += 3 # add pixels to pad text\n outside = p1[1] >= h # label fits outside box\n if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image\n p1 = self.im.shape[1] - w, p1[1]\n p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h\n cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled\n cv2.putText(\n self.im,\n label,\n (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),\n 0,\n self.sf,\n txt_color,\n thickness=self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False):\n \"\"\"\n Plot masks on image.\n\n Args:\n masks (torch.Tensor): Predicted masks on cuda, shape: [n, h, w]\n colors (List[List[int]]): Colors for predicted masks, [[r, g, b] * n]\n im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]\n alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.\n retina_masks (bool, optional): Whether to use high resolution masks or not.\n \"\"\"\n if self.pil:\n # Convert to numpy first\n self.im = np.asarray(self.im).copy()\n if len(masks) == 0:\n self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255\n if im_gpu.device != masks.device:\n im_gpu = im_gpu.to(masks.device)\n colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)\n colors = colors[:, None, None] # shape(n,1,1,3)\n masks = masks.unsqueeze(3) # shape(n,h,w,1)\n masks_color = masks * (colors * alpha) # shape(n,h,w,3)\n\n inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)\n mcs = masks_color.max(dim=0).values # shape(n,h,w,3)\n\n im_gpu = im_gpu.flip(dims=[0]) # flip channel\n im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)\n im_gpu = im_gpu * inv_alpha_masks[-1] + 
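To make the box_label() placement logic concrete, here is a minimal sketch on a blank image (dummy coordinates and labels): when the top of the box leaves enough room, the filled label background sits above the box, otherwise it drops inside. box_label() also accepts a list of [x, y] corner points, which routes through the polygon branch instead of cv2.rectangle; that path is not exercised here.

import numpy as np
from ultralytics.utils.plotting import Annotator

im = np.full((480, 640, 3), 255, dtype=np.uint8)
ann = Annotator(im, line_width=2)

# Plenty of headroom: the label background is drawn just above the box (outside == True)
ann.box_label([50, 60, 200, 220], label="person 0.91", color=(255, 42, 4))

# Box flush with the top edge: p1[1] < text height, so the label is drawn inside the box
ann.box_label([250, 2, 400, 120], label="car 0.80", color=(4, 42, 255))

annotated = ann.result()  # numpy array with both boxes and labels drawn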
mcs\n im_mask = im_gpu * 255\n im_mask_np = im_mask.byte().cpu().numpy()\n self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)\n if self.pil:\n # Convert im back to PIL and update draw\n self.fromarray(self.im)\n\n def kpts(\n self,\n kpts,\n shape: tuple = (640, 640),\n radius: Optional[int] = None,\n kpt_line: bool = True,\n conf_thres: float = 0.25,\n kpt_color: Optional[tuple] = None,\n ):\n \"\"\"\n Plot keypoints on the image.\n\n Args:\n kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).\n shape (tuple, optional): Image shape (h, w).\n radius (int, optional): Keypoint radius.\n kpt_line (bool, optional): Draw lines between keypoints.\n conf_thres (float, optional): Confidence threshold.\n kpt_color (tuple, optional): Keypoint color (B, G, R).\n\n Note:\n - `kpt_line=True` currently only supports human pose plotting.\n - Modifies self.im in-place.\n - If self.pil is True, converts image to numpy array and back to PIL.\n \"\"\"\n radius = radius if radius is not None else self.lw\n if self.pil:\n # Convert to numpy first\n self.im = np.asarray(self.im).copy()\n nkpt, ndim = kpts.shape\n is_pose = nkpt == 17 and ndim in {2, 3}\n kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting\n for i, k in enumerate(kpts):\n color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))\n x_coord, y_coord = k[0], k[1]\n if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:\n if len(k) == 3:\n conf = k[2]\n if conf < conf_thres:\n continue\n cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)\n\n if kpt_line:\n ndim = kpts.shape[-1]\n for i, sk in enumerate(self.skeleton):\n pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))\n pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))\n if ndim == 3:\n conf1 = kpts[(sk[0] - 1), 2]\n conf2 = kpts[(sk[1] - 1), 2]\n if conf1 < conf_thres or conf2 < conf_thres:\n continue\n if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:\n continue\n if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:\n continue\n cv2.line(\n self.im,\n pos1,\n pos2,\n kpt_color or self.limb_color[i].tolist(),\n thickness=int(np.ceil(self.lw / 2)),\n lineType=cv2.LINE_AA,\n )\n if self.pil:\n # Convert im back to PIL and update draw\n self.fromarray(self.im)\n\n def rectangle(self, xy, fill=None, outline=None, width: int = 1):\n \"\"\"Add rectangle to image (PIL-only).\"\"\"\n self.draw.rectangle(xy, fill, outline, width)\n\n def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = \"top\", box_color: tuple = ()):\n \"\"\"\n Add text to an image using PIL or cv2.\n\n Args:\n xy (List[int]): Top-left coordinates for text placement.\n text (str): Text to be drawn.\n txt_color (tuple, optional): Text color (R, G, B).\n anchor (str, optional): Text anchor position ('top' or 'bottom').\n box_color (tuple, optional): Box color (R, G, B, A) with optional alpha.\n \"\"\"\n if self.pil:\n w, h = self.font.getsize(text)\n if anchor == \"bottom\": # start y from font bottom\n xy[1] += 1 - h\n for line in text.split(\"\\n\"):\n if box_color:\n # Draw rectangle for each line\n w, h = self.font.getsize(line)\n self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=box_color)\n self.draw.text(xy, line, fill=txt_color, font=self.font)\n xy[1] += h\n else:\n if box_color:\n w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]\n 
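The compositing in masks() is easier to follow with toy tensors. This standalone sketch (dummy image and masks, alpha=0.5) reproduces just the math: each mask contributes color * alpha, and the running cumprod of (1 - alpha * mask) attenuates the underlying image where masks overlap. The real method additionally flips BGR/RGB, rescales to 0-255, and resizes non-retina masks.

import torch

alpha = 0.5
im = torch.ones(3, 4, 4)                     # dummy image in [0, 1], shape (3, h, w)
masks = torch.zeros(2, 4, 4)                 # two binary masks, shape (n, h, w)
masks[0, :2], masks[1, 2:] = 1, 1            # mask 0 covers the top half, mask 1 the bottom
colors = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])   # red and blue, already / 255

masks = masks.unsqueeze(3)                                  # (n, h, w, 1)
masks_color = masks * (colors[:, None, None] * alpha)       # (n, h, w, 3)
inv_alpha = (1 - masks * alpha).cumprod(0)                  # (n, h, w, 1)
mcs = masks_color.max(dim=0).values                         # (h, w, 3)

out = im.permute(1, 2, 0) * inv_alpha[-1] + mcs             # composited image, (h, w, 3)
print(out[0, 0], out[3, 0])   # [1.0, 0.5, 0.5] red tint on top, [0.5, 0.5, 1.0] blue tint below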
h += 3 # add pixels to pad text\n outside = xy[1] >= h # label fits outside box\n p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h\n cv2.rectangle(self.im, xy, p2, box_color, -1, cv2.LINE_AA) # filled\n cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)\n\n def fromarray(self, im):\n \"\"\"Update self.im from a numpy array.\"\"\"\n self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)\n self.draw = ImageDraw.Draw(self.im)\n\n def result(self):\n \"\"\"Return annotated image as array.\"\"\"\n return np.asarray(self.im)\n\n def show(self, title: Optional[str] = None):\n \"\"\"Show the annotated image.\"\"\"\n im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR\n if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments\n try:\n display(im) # noqa - display() function only available in ipython environments\n except ImportError as e:\n LOGGER.warning(f\"Unable to display image in Jupyter notebooks: {e}\")\n else:\n im.show(title=title)\n\n def save(self, filename: str = \"image.jpg\"):\n \"\"\"Save the annotated image to 'filename'.\"\"\"\n cv2.imwrite(filename, np.asarray(self.im))\n\n @staticmethod\n def get_bbox_dimension(bbox: Optional[tuple] = None):\n \"\"\"\n Calculate the dimensions and area of a bounding box.\n\n Args:\n bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).\n\n Returns:\n width (float): Width of the bounding box.\n height (float): Height of the bounding box.\n area (float): Area enclosed by the bounding box.\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40])\n \"\"\"\n x_min, y_min, x_max, y_max = bbox\n width = x_max - x_min\n height = y_max - y_min\n return width, height, width * height", "chunk_type": "class", "name": "Annotator", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 168, "end_line": 549, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.\n\nAttributes:\n im (Image.Image | np.ndarray): The image to annotate.\n pil (bool): Whether to use PIL or cv2 for drawing annotations.\n font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.\n lw (float): Line width for drawing.\n skeleton (List[List[int]]): Skeleton structure for keypoints.\n limb_color (List[int]): Color palette for limbs.\n kpt_color (List[int]): Color palette for keypoints.\n dark_colors (set): Set of colors considered dark for text contrast.\n light_colors (set): Set of colors considered light for text contrast.\n\nExamples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.box_label([10, 10, 100, 100], \"person\", (255, 0, 0))", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", 
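Putting a few of the Annotator pieces above together, a hedged end-to-end sketch with random COCO-style keypoints on a dummy canvas (17 rows of (x, y, conf), the layout kpts() expects); points below conf_thres are skipped and skeleton limbs are drawn between surviving pairs.

import numpy as np
import torch
from ultralytics.utils.plotting import Annotator

im = np.zeros((640, 640, 3), dtype=np.uint8)
ann = Annotator(im, line_width=2)

kpts = torch.rand(17, 3)        # random (x, y, conf) keypoints
kpts[:, 0] *= 640               # x in pixels
kpts[:, 1] *= 640               # y in pixels
ann.kpts(kpts, shape=(640, 640), kpt_line=True, conf_thres=0.25)

print(ann.get_bbox_dimension([100, 150, 300, 400]))  # (200, 250, 50000)
ann.save("keypoints.jpg")       # arbitrary output filename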
"ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "class_Annotator_82473fd1" }, { "content": "def plot_labels(boxes, cls, names=(), save_dir=Path(\"\"), on_plot=None):\n \"\"\"\n Plot training labels including class histograms and box statistics.\n\n Args:\n boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].\n cls (np.ndarray): Class indices.\n names (dict, optional): Dictionary mapping class indices to class names.\n save_dir (Path, optional): Directory to save the plot.\n on_plot (Callable, optional): Function to call after plot is saved.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n import pandas\n from matplotlib.colors import LinearSegmentedColormap\n\n # Filter matplotlib>=3.7.2 warning\n warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"The figure layout has changed to tight\")\n warnings.filterwarnings(\"ignore\", category=FutureWarning)\n\n # Plot dataset labels\n LOGGER.info(f\"Plotting labels to {save_dir / 'labels.jpg'}... \")\n nc = int(cls.max() + 1) # number of classes\n boxes = boxes[:1000000] # limit to 1M boxes\n x = pandas.DataFrame(boxes, columns=[\"x\", \"y\", \"width\", \"height\"])\n\n try: # Seaborn correlogram\n import seaborn\n\n seaborn.pairplot(x, corner=True, diag_kind=\"auto\", kind=\"hist\", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))\n plt.savefig(save_dir / \"labels_correlogram.jpg\", dpi=200)\n plt.close()\n except ImportError:\n pass # Skip if seaborn is not installed\n\n # Matplotlib labels\n subplot_3_4_color = LinearSegmentedColormap.from_list(\"white_blue\", [\"white\", \"blue\"])\n ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()\n y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)\n for i in range(nc):\n y[2].patches[i].set_color([x / 255 for x in colors(i)])\n ax[0].set_ylabel(\"instances\")\n if 0 < len(names) < 30:\n ax[0].set_xticks(range(len(names)))\n ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)\n else:\n ax[0].set_xlabel(\"classes\")\n boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000\n img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)\n for cls, box in zip(cls[:500], boxes[:500]):\n ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot\n ax[1].imshow(img)\n ax[1].axis(\"off\")\n\n ax[2].hist2d(x[\"x\"], x[\"y\"], bins=50, cmap=subplot_3_4_color)\n ax[2].set_xlabel(\"x\")\n ax[2].set_ylabel(\"y\")\n ax[3].hist2d(x[\"width\"], x[\"height\"], bins=50, cmap=subplot_3_4_color)\n ax[3].set_xlabel(\"width\")\n ax[3].set_ylabel(\"height\")\n for a in {0, 1, 2, 3}:\n for s in {\"top\", \"right\", \"left\", \"bottom\"}:\n ax[a].spines[s].set_visible(False)\n\n fname = save_dir / \"labels.jpg\"\n plt.savefig(fname, dpi=200)\n plt.close()\n if on_plot:\n on_plot(fname)", "chunk_type": "function", "name": "plot_labels", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 554, "end_line": 621, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Plot 
training labels including class histograms and box statistics.\n\nArgs:\n boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].\n cls (np.ndarray): Class indices.\n names (dict, optional): Dictionary mapping class indices to class names.\n save_dir (Path, optional): Directory to save the plot.\n on_plot (Callable, optional): Function to call after plot is saved.", "parameters": [ "boxes", "cls", "names", "save_dir", "on_plot" ], "return_type": null, "decorators": [ "TryExcept()", "plt_settings()" ], "complexity_score": 9, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_plot_labels_0423c306" }, { "content": "def save_one_box(\n xyxy,\n im,\n file: Path = Path(\"im.jpg\"),\n gain: float = 1.02,\n pad: int = 10,\n square: bool = False,\n BGR: bool = False,\n save: bool = True,\n):\n \"\"\"\n Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.\n\n This function takes a bounding box and an image, and then saves a cropped portion of the image according\n to the bounding box. 
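A sketch of calling plot_labels() with synthetic dataset statistics (random normalized xywh boxes, three made-up class names, an arbitrary output directory). The function is wrapped in TryExcept, so the directory is created up front to avoid a silently swallowed savefig failure.

import numpy as np
from pathlib import Path
from ultralytics.utils.plotting import plot_labels

rng = np.random.default_rng(0)
boxes = np.clip(rng.normal(0.5, 0.15, (200, 4)), 0.05, 0.95).astype(np.float32)  # normalized x, y, w, h
cls = rng.integers(0, 3, 200).astype(np.float32)                                 # class index per box
names = {0: "person", 1: "car", 2: "dog"}                                        # hypothetical class names

save_dir = Path("runs/example")
save_dir.mkdir(parents=True, exist_ok=True)
plot_labels(boxes, cls, names=names, save_dir=save_dir)  # writes labels.jpg (+ correlogram if seaborn is installed)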
Optionally, the crop can be squared, and the function allows for gain and padding\n adjustments to the bounding box.\n\n Args:\n xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.\n im (np.ndarray): The input image.\n file (Path, optional): The path where the cropped image will be saved.\n gain (float, optional): A multiplicative factor to increase the size of the bounding box.\n pad (int, optional): The number of pixels to add to the width and height of the bounding box.\n square (bool, optional): If True, the bounding box will be transformed into a square.\n BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.\n save (bool, optional): If True, the cropped image will be saved to disk.\n\n Returns:\n (np.ndarray): The cropped image.\n\n Examples:\n >>> from ultralytics.utils.plotting import save_one_box\n >>> xyxy = [50, 50, 150, 150]\n >>> im = cv2.imread(\"image.jpg\")\n >>> cropped_im = save_one_box(xyxy, im, file=\"cropped.jpg\", square=True)\n \"\"\"\n if not isinstance(xyxy, torch.Tensor): # may be list\n xyxy = torch.stack(xyxy)\n b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes\n if square:\n b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square\n b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad\n xyxy = ops.xywh2xyxy(b).long()\n xyxy = ops.clip_boxes(xyxy, im.shape)\n grayscale = im.shape[2] == 1 # grayscale image\n crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]\n if save:\n file.parent.mkdir(parents=True, exist_ok=True) # make directory\n f = str(increment_path(file).with_suffix(\".jpg\"))\n # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue\n crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop\n Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB\n return crop", "chunk_type": "function", "name": "save_one_box", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 624, "end_line": 676, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.\n\nThis function takes a bounding box and an image, and then saves a cropped portion of the image according\nto the bounding box. 
Optionally, the crop can be squared, and the function allows for gain and padding\nadjustments to the bounding box.\n\nArgs:\n xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.\n im (np.ndarray): The input image.\n file (Path, optional): The path where the cropped image will be saved.\n gain (float, optional): A multiplicative factor to increase the size of the bounding box.\n pad (int, optional): The number of pixels to add to the width and height of the bounding box.\n square (bool, optional): If True, the bounding box will be transformed into a square.\n BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.\n save (bool, optional): If True, the cropped image will be saved to disk.\n\nReturns:\n (np.ndarray): The cropped image.\n\nExamples:\n >>> from ultralytics.utils.plotting import save_one_box\n >>> xyxy = [50, 50, 150, 150]\n >>> im = cv2.imread(\"image.jpg\")\n >>> cropped_im = save_one_box(xyxy, im, file=\"cropped.jpg\", square=True)", "parameters": [ "xyxy", "im", "file: Path", "gain: float", "pad: int", "square: bool", "BGR: bool", "save: bool" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_save_one_box_081a27c2" }, { "content": "def plot_images(\n labels: Dict[str, Any],\n images: Union[torch.Tensor, np.ndarray] = np.zeros((0, 3, 640, 640), dtype=np.float32),\n paths: Optional[List[str]] = None,\n fname: str = \"images.jpg\",\n names: Optional[Dict[int, str]] = None,\n on_plot: Optional[Callable] = None,\n max_size: int = 1920,\n max_subplots: int = 16,\n save: bool = True,\n conf_thres: float = 0.25,\n) -> Optional[np.ndarray]:\n \"\"\"\n Plot image grid with labels, bounding boxes, masks, and keypoints.\n\n Args:\n labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.\n images (torch.Tensor | np.ndarray]): Batch of images to plot. 
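The gain/pad arithmetic inside save_one_box() can be followed on paper. This sketch repeats just the box-growing steps for a single 100x60 box with gain=1.02, pad=10 and square=True, using the same ops helpers the function relies on; the actual function then clips the box and slices the image.

import torch
from ultralytics.utils import ops

xyxy = torch.tensor([[200.0, 120.0, 300.0, 180.0]])   # one dummy box
b = ops.xyxy2xywh(xyxy)                     # -> center (250, 150), size (100, 60)
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # square=True: both sides become 100
b[:, 2:] = b[:, 2:] * 1.02 + 10             # apply gain and pad -> 112 x 112
crop_box = ops.xywh2xyxy(b).long()          # back to corners for clipping/slicing
print(crop_box)                             # tensor([[194,  94, 306, 206]])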
Shape: (batch_size, channels, height, width).\n paths (Optional[List[str]]): List of file paths for each image in the batch.\n fname (str): Output filename for the plotted image grid.\n names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names.\n on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.\n max_size (int): Maximum size of the output image grid.\n max_subplots (int): Maximum number of subplots in the image grid.\n save (bool): Whether to save the plotted image grid to a file.\n conf_thres (float): Confidence threshold for displaying detections.\n\n Returns:\n (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.\n\n Note:\n This function supports both tensor and numpy array inputs. It will automatically\n convert tensor inputs to numpy arrays for processing.\n \"\"\"\n for k in {\"cls\", \"bboxes\", \"conf\", \"masks\", \"keypoints\", \"batch_idx\", \"images\"}:\n if k not in labels:\n continue\n if k == \"cls\" and labels[k].ndim == 2:\n labels[k] = labels[k].squeeze(1) # squeeze if shape is (n, 1)\n if isinstance(labels[k], torch.Tensor):\n labels[k] = labels[k].cpu().numpy()\n\n cls = labels.get(\"cls\", np.zeros(0, dtype=np.int64))\n batch_idx = labels.get(\"batch_idx\", np.zeros(cls.shape, dtype=np.int64))\n bboxes = labels.get(\"bboxes\", np.zeros(0, dtype=np.float32))\n confs = labels.get(\"conf\", None)\n masks = labels.get(\"masks\", np.zeros(0, dtype=np.uint8))\n kpts = labels.get(\"keypoints\", np.zeros(0, dtype=np.float32))\n images = labels.get(\"img\", images) # default to input images\n\n if len(images) and isinstance(images, torch.Tensor):\n images = images.cpu().float().numpy()\n if images.shape[1] > 3:\n images = images[:, :3] # crop multispectral images to first 3 channels\n\n bs, _, h, w = images.shape # batch size, _, height, width\n bs = min(bs, max_subplots) # limit plot images\n ns = np.ceil(bs**0.5) # number of subplots (square)\n if np.max(images[0]) <= 1:\n images *= 255 # de-normalise (optional)\n\n # Build Image\n mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init\n for i in range(bs):\n x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin\n mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)\n\n # Resize (optional)\n scale = max_size / ns / max(h, w)\n if scale < 1:\n h = math.ceil(scale * h)\n w = math.ceil(scale * w)\n mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))\n\n # Annotate\n fs = int((h + w) * ns * 0.01) # font size\n fs = max(fs, 18) # ensure that the font size is large enough to be easily readable.\n annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=str(names))\n for i in range(bs):\n x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin\n annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders\n if paths:\n annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames\n if len(cls) > 0:\n idx = batch_idx == i\n classes = cls[idx].astype(\"int\")\n labels = confs is None\n\n if len(bboxes):\n boxes = bboxes[idx]\n conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)\n if len(boxes):\n if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1\n boxes[..., [0, 2]] *= w # scale to pixels\n boxes[..., [1, 3]] *= h\n elif scale < 1: # absolute coords need scale if image scales\n boxes[..., :4] *= scale\n boxes[..., 0] += x\n 
boxes[..., 1] += y\n is_obb = boxes.shape[-1] == 5 # xywhr\n # TODO: this transformation might be unnecessary\n boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)\n for j, box in enumerate(boxes.astype(np.int64).tolist()):\n c = classes[j]\n color = colors(c)\n c = names.get(c, c) if names else c\n if labels or conf[j] > conf_thres:\n label = f\"{c}\" if labels else f\"{c} {conf[j]:.1f}\"\n annotator.box_label(box, label, color=color)\n\n elif len(classes):\n for c in classes:\n color = colors(c)\n c = names.get(c, c) if names else c\n annotator.text([x, y], f\"{c}\", txt_color=color, box_color=(64, 64, 64, 128))\n\n # Plot keypoints\n if len(kpts):\n kpts_ = kpts[idx].copy()\n if len(kpts_):\n if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01\n kpts_[..., 0] *= w # scale to pixels\n kpts_[..., 1] *= h\n elif scale < 1: # absolute coords need scale if image scales\n kpts_ *= scale\n kpts_[..., 0] += x\n kpts_[..., 1] += y\n for j in range(len(kpts_)):\n if labels or conf[j] > conf_thres:\n annotator.kpts(kpts_[j], conf_thres=conf_thres)\n\n # Plot masks\n if len(masks):\n if idx.shape[0] == masks.shape[0]: # overlap_masks=False\n image_masks = masks[idx]\n else: # overlap_masks=True\n image_masks = masks[[i]] # (1, 640, 640)\n nl = idx.sum()\n index = np.arange(nl).reshape((nl, 1, 1)) + 1\n image_masks = np.repeat(image_masks, nl, axis=0)\n image_masks = np.where(image_masks == index, 1.0, 0.0)\n\n im = np.asarray(annotator.im).copy()\n for j in range(len(image_masks)):\n if labels or conf[j] > conf_thres:\n color = colors(classes[j])\n mh, mw = image_masks[j].shape\n if mh != h or mw != w:\n mask = image_masks[j].astype(np.uint8)\n mask = cv2.resize(mask, (w, h))\n mask = mask.astype(bool)\n else:\n mask = image_masks[j].astype(bool)\n try:\n im[y : y + h, x : x + w, :][mask] = (\n im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6\n )\n except Exception:\n pass\n annotator.fromarray(im)\n if not save:\n return np.asarray(annotator.im)\n annotator.im.save(fname) # save\n if on_plot:\n on_plot(fname)", "chunk_type": "function", "name": "plot_images", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 680, "end_line": 844, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Plot image grid with labels, bounding boxes, masks, and keypoints.\n\nArgs:\n labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.\n images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).\n paths (Optional[List[str]]): List of file paths for each image in the batch.\n fname (str): Output filename for the plotted image grid.\n names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names.\n on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.\n max_size (int): Maximum size of the output image grid.\n max_subplots (int): Maximum number of subplots in the image grid.\n save (bool): Whether to save the plotted image grid to a file.\n conf_thres (float): Confidence threshold for displaying detections.\n\nReturns:\n (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.\n\nNote:\n This function supports both tensor and numpy array inputs. 
It will automatically\n convert tensor inputs to numpy arrays for processing.", "parameters": [ "labels: Dict[str, Any]", "images: Union[torch.Tensor, np.ndarray]", "paths: Optional[List[str]]", "fname: str", "names: Optional[Dict[int, str]]", "on_plot: Optional[Callable]", "max_size: int", "max_subplots: int", "save: bool", "conf_thres: float" ], "return_type": "Optional[np.ndarray]", "decorators": [ "threaded" ], "complexity_score": 36, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_plot_images_4ec7aab8" }, { "content": "def plot_results(\n file: str = \"path/to/results.csv\",\n dir: str = \"\",\n segment: bool = False,\n pose: bool = False,\n classify: bool = False,\n on_plot: Optional[Callable] = None,\n):\n \"\"\"\n Plot training results from a results CSV file. The function supports various types of data including segmentation,\n pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.\n\n Args:\n file (str, optional): Path to the CSV file containing the training results.\n dir (str, optional): Directory where the CSV file is located if 'file' is not provided.\n segment (bool, optional): Flag to indicate if the data is for segmentation.\n pose (bool, optional): Flag to indicate if the data is for pose estimation.\n classify (bool, optional): Flag to indicate if the data is for classification.\n on_plot (callable, optional): Callback function to be executed after plotting. 
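A minimal, hedged sketch of driving plot_images() directly with a hand-built labels dict (two random images, three normalized xywh boxes, made-up class names and an arbitrary filename). Because the function is decorated with @threaded, the call returns while the mosaic is written in the background.

import torch
from ultralytics.utils.plotting import plot_images

labels = {
    "img": torch.rand(2, 3, 320, 320),                 # batch of images in [0, 1]
    "batch_idx": torch.tensor([0, 0, 1]),              # which image each box belongs to
    "cls": torch.tensor([0, 2, 1]),
    "bboxes": torch.tensor([[0.50, 0.50, 0.40, 0.30],
                            [0.25, 0.25, 0.20, 0.20],
                            [0.60, 0.40, 0.30, 0.50]]),  # normalized xywh
}
names = {0: "person", 1: "car", 2: "dog"}              # hypothetical class names

plot_images(labels, fname="train_batch_demo.jpg", names=names)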
Takes filename as an argument.\n\n Examples:\n >>> from ultralytics.utils.plotting import plot_results\n >>> plot_results(\"path/to/results.csv\", segment=True)\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n import pandas as pd\n from scipy.ndimage import gaussian_filter1d\n\n save_dir = Path(file).parent if file else Path(dir)\n if classify:\n fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)\n index = [2, 5, 3, 4]\n elif segment:\n fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)\n index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]\n elif pose:\n fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)\n index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]\n else:\n fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)\n index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]\n ax = ax.ravel()\n files = list(save_dir.glob(\"results*.csv\"))\n assert len(files), f\"No results.csv files found in {save_dir.resolve()}, nothing to plot.\"\n for f in files:\n try:\n data = pd.read_csv(f)\n s = [x.strip() for x in data.columns]\n x = data.values[:, 0]\n for i, j in enumerate(index):\n y = data.values[:, j].astype(\"float\")\n # y[y == 0] = np.nan # don't show zero values\n ax[i].plot(x, y, marker=\".\", label=f.stem, linewidth=2, markersize=8) # actual results\n ax[i].plot(x, gaussian_filter1d(y, sigma=3), \":\", label=\"smooth\", linewidth=2) # smoothing line\n ax[i].set_title(s[j], fontsize=12)\n # if j in {8, 9, 10}: # share train and val loss y axes\n # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])\n except Exception as e:\n LOGGER.error(f\"Plotting error for {f}: {e}\")\n ax[1].legend()\n fname = save_dir / \"results.png\"\n fig.savefig(fname, dpi=200)\n plt.close()\n if on_plot:\n on_plot(fname)", "chunk_type": "function", "name": "plot_results", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 848, "end_line": 912, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Plot training results from a results CSV file. The function supports various types of data including segmentation,\npose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.\n\nArgs:\n file (str, optional): Path to the CSV file containing the training results.\n dir (str, optional): Directory where the CSV file is located if 'file' is not provided.\n segment (bool, optional): Flag to indicate if the data is for segmentation.\n pose (bool, optional): Flag to indicate if the data is for pose estimation.\n classify (bool, optional): Flag to indicate if the data is for classification.\n on_plot (callable, optional): Callback function to be executed after plotting. 
Takes filename as an argument.\n\nExamples:\n >>> from ultralytics.utils.plotting import plot_results\n >>> plot_results(\"path/to/results.csv\", segment=True)", "parameters": [ "file: str", "dir: str", "segment: bool", "pose: bool", "classify: bool", "on_plot: Optional[Callable]" ], "return_type": null, "decorators": [ "plt_settings()" ], "complexity_score": 9, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_plot_results_72e5b78e" }, { "content": "def plt_color_scatter(v, f, bins: int = 20, cmap: str = \"viridis\", alpha: float = 0.8, edgecolors: str = \"none\"):\n \"\"\"\n Plot a scatter plot with points colored based on a 2D histogram.\n\n Args:\n v (array-like): Values for the x-axis.\n f (array-like): Values for the y-axis.\n bins (int, optional): Number of bins for the histogram.\n cmap (str, optional): Colormap for the scatter plot.\n alpha (float, optional): Alpha for the scatter plot.\n edgecolors (str, optional): Edge colors for the scatter plot.\n\n Examples:\n >>> v = np.random.rand(100)\n >>> f = np.random.rand(100)\n >>> plt_color_scatter(v, f)\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n # Calculate 2D histogram and corresponding colors\n hist, xedges, yedges = np.histogram2d(v, f, bins=bins)\n colors = [\n hist[\n min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),\n min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),\n ]\n for i in range(len(v))\n ]\n\n # Scatter plot\n plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)", "chunk_type": "function", "name": "plt_color_scatter", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 915, "end_line": 945, "start_col": 0, "end_col": 78, "parent_name": null, "docstring": "Plot a scatter plot with points colored based on a 2D histogram.\n\nArgs:\n v (array-like): Values for the x-axis.\n f (array-like): Values for the y-axis.\n bins (int, optional): Number of bins for the histogram.\n cmap (str, optional): Colormap for the scatter plot.\n alpha (float, optional): Alpha for the scatter plot.\n edgecolors (str, optional): Edge colors for the scatter plot.\n\nExamples:\n >>> v = np.random.rand(100)\n >>> f = np.random.rand(100)\n >>> plt_color_scatter(v, f)", "parameters": [ "v", "f", "bins: int", "cmap: str", "alpha: float", "edgecolors: str" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", 
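plt_color_scatter() above is self-contained enough for a quick sketch: points are colored by how populated their 2D-histogram bin is, which makes dense regions stand out. Dummy correlated data and an arbitrary output filename are assumed.

import numpy as np
import matplotlib.pyplot as plt
from ultralytics.utils.plotting import plt_color_scatter

v = np.random.rand(500)
f = v + 0.1 * np.random.randn(500)           # correlated with v, so the diagonal is dense
plt_color_scatter(v, f, bins=25, cmap="viridis")
plt.savefig("density_scatter.png", dpi=150)  # arbitrary filename
plt.close()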
"ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_plt_color_scatter_95489aff" }, { "content": "def plot_tune_results(csv_file: str = \"tune_results.csv\"):\n \"\"\"\n Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key\n in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.\n\n Args:\n csv_file (str, optional): Path to the CSV file containing the tuning results.\n\n Examples:\n >>> plot_tune_results(\"path/to/tune_results.csv\")\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n import pandas as pd\n from scipy.ndimage import gaussian_filter1d\n\n def _save_one_file(file):\n \"\"\"Save one matplotlib plot to 'file'.\"\"\"\n plt.savefig(file, dpi=200)\n plt.close()\n LOGGER.info(f\"Saved {file}\")\n\n # Scatter plots for each hyperparameter\n csv_file = Path(csv_file)\n data = pd.read_csv(csv_file)\n num_metrics_columns = 1\n keys = [x.strip() for x in data.columns][num_metrics_columns:]\n x = data.values\n fitness = x[:, 0] # fitness\n j = np.argmax(fitness) # max fitness index\n n = math.ceil(len(keys) ** 0.5) # columns and rows in plot\n plt.figure(figsize=(10, 10), tight_layout=True)\n for i, k in enumerate(keys):\n v = x[:, i + num_metrics_columns]\n mu = v[j] # best single result\n plt.subplot(n, n, i + 1)\n plt_color_scatter(v, fitness, cmap=\"viridis\", alpha=0.8, edgecolors=\"none\")\n plt.plot(mu, fitness.max(), \"k+\", markersize=15)\n plt.title(f\"{k} = {mu:.3g}\", fontdict={\"size\": 9}) # limit to 40 characters\n plt.tick_params(axis=\"both\", labelsize=8) # Set axis label size to 8\n if i % n != 0:\n plt.yticks([])\n _save_one_file(csv_file.with_name(\"tune_scatter_plots.png\"))\n\n # Fitness vs iteration\n x = range(1, len(fitness) + 1)\n plt.figure(figsize=(10, 6), tight_layout=True)\n plt.plot(x, fitness, marker=\"o\", linestyle=\"none\", label=\"fitness\")\n plt.plot(x, gaussian_filter1d(fitness, sigma=3), \":\", label=\"smoothed\", linewidth=2) # smoothing line\n plt.title(\"Fitness vs Iteration\")\n plt.xlabel(\"Iteration\")\n plt.ylabel(\"Fitness\")\n plt.grid(True)\n plt.legend()\n _save_one_file(csv_file.with_name(\"tune_fitness.png\"))", "chunk_type": "function", "name": "plot_tune_results", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 948, "end_line": 1001, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": "Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key\nin the CSV, color-coded based on fitness scores. 
The best-performing configurations are highlighted on the plots.\n\nArgs:\n csv_file (str, optional): Path to the CSV file containing the tuning results.\n\nExamples:\n >>> plot_tune_results(\"path/to/tune_results.csv\")", "parameters": [ "csv_file: str" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_plot_tune_results_08ba113c" }, { "content": "def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path(\"runs/detect/exp\")):\n \"\"\"\n Visualize feature maps of a given model module during inference.\n\n Args:\n x (torch.Tensor): Features to be visualized.\n module_type (str): Module type.\n stage (int): Module stage within the model.\n n (int, optional): Maximum number of feature maps to plot.\n save_dir (Path, optional): Directory to save results.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n for m in {\"Detect\", \"Segment\", \"Pose\", \"Classify\", \"OBB\", \"RTDETRDecoder\"}: # all model heads\n if m in module_type:\n return\n if isinstance(x, torch.Tensor):\n _, channels, height, width = x.shape # batch, channels, height, width\n if height > 1 and width > 1:\n f = save_dir / f\"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png\" # filename\n\n blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels\n n = min(n, channels) # number of plots\n _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols\n ax = ax.ravel()\n plt.subplots_adjust(wspace=0.05, hspace=0.05)\n for i in range(n):\n ax[i].imshow(blocks[i].squeeze()) # cmap='gray'\n ax[i].axis(\"off\")\n\n LOGGER.info(f\"Saving {f}... 
({n}/{channels})\")\n plt.savefig(f, dpi=300, bbox_inches=\"tight\")\n plt.close()\n np.save(str(f.with_suffix(\".npy\")), x[0].cpu().numpy()) # npy save", "chunk_type": "function", "name": "feature_visualization", "file_path": "ultralytics\\ultralytics\\utils\\plotting.py", "start_line": 1004, "end_line": 1037, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "Visualize feature maps of a given model module during inference.\n\nArgs:\n x (torch.Tensor): Features to be visualized.\n module_type (str): Module type.\n stage (int): Module stage within the model.\n n (int, optional): Maximum number of feature maps to plot.\n save_dir (Path, optional): Directory to save results.", "parameters": [ "x", "module_type: str", "stage: int", "n: int", "save_dir: Path" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "math", "warnings", "pathlib.Path", "typing.Any", "typing.Callable", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "cv2", "numpy", "torch", "PIL.Image", "PIL.ImageDraw", "PIL.ImageFont", "PIL.__version__", "ultralytics.utils.IS_COLAB", "ultralytics.utils.IS_KAGGLE", "ultralytics.utils.LOGGER", "ultralytics.utils.TryExcept", "ultralytics.utils.ops", "ultralytics.utils.plt_settings", "ultralytics.utils.threaded", "ultralytics.utils.checks.check_font", "ultralytics.utils.checks.check_version", "ultralytics.utils.checks.is_ascii", "ultralytics.utils.files.increment_path", "matplotlib.pyplot", "pandas", "matplotlib.colors.LinearSegmentedColormap", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "matplotlib.pyplot", "pandas", "scipy.ndimage.gaussian_filter1d", "matplotlib.pyplot", "seaborn" ], "chunk_id": "function_feature_visualization_aab84be0" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_3f7c74dd" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_88e7eb7d" }, { "content": "from . 
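feature_visualization() only needs a 4-D activation tensor and a module label, so it can be exercised without a model. In this sketch the feature map is random, the module_type string is made up (anything not containing a head name such as Detect or Segment), and the save directory is created beforehand.

import torch
from pathlib import Path
from ultralytics.utils.plotting import feature_visualization

feats = torch.randn(1, 64, 40, 40)            # fake (batch, channels, h, w) activations
save_dir = Path("runs/feature_maps")
save_dir.mkdir(parents=True, exist_ok=True)
feature_visualization(feats, module_type="model.4.C2f", stage=4, n=16, save_dir=save_dir)
# -> runs/feature_maps/stage4_C2f_features.png plus a .npy dump of the raw features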
import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_911cfcde" }, { "content": "from .checks import check_version", "chunk_type": "import", "name": "check_version", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_version_2d98f8c9" }, { "content": "from .metrics import bbox_iou, probiou", "chunk_type": "import", "name": "bbox_iou, probiou", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bbox_iou, probiou_0f7c9b26" }, { "content": "from .ops import xywhr2xyxyxyxy", "chunk_type": "import", "name": "xywhr2xyxyxyxy", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_xywhr2xyxyxyxy_847e17a6" }, { "content": "TORCH_1_10 = check_version(torch.__version__, \"1.10.0\")", "chunk_type": "variable", "name": "TORCH_1_10", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCH_1_10_3afeeabf" }, { "content": "class TaskAlignedAssigner(nn.Module):\n \"\"\"\n A task-aligned assigner for object detection.\n\n This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both\n classification and localization information.\n\n Attributes:\n topk (int): The number of top candidates to consider.\n num_classes (int): The number of object classes.\n alpha (float): The alpha parameter for the classification component of the task-aligned metric.\n beta (float): The beta parameter for the localization component of the task-aligned metric.\n eps (float): A small value to prevent division by zero.\n \"\"\"\n\n def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):\n \"\"\"\n Initialize a TaskAlignedAssigner object with customizable hyperparameters.\n\n Args:\n topk (int, optional): The number of top candidates to consider.\n num_classes (int, optional): The number of object classes.\n alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.\n beta (float, optional): The beta parameter for the localization component of the task-aligned metric.\n eps (float, optional): A small value to prevent division by zero.\n \"\"\"\n super().__init__()\n self.topk = topk\n self.num_classes = num_classes\n self.alpha = alpha\n self.beta = beta\n self.eps = eps\n\n @torch.no_grad()\n def forward(self, pd_scores, pd_bboxes, anc_points, 
gt_labels, gt_bboxes, mask_gt):\n \"\"\"\n Compute the task-aligned assignment.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).\n\n Returns:\n target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).\n target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).\n target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).\n fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).\n target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).\n\n References:\n https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py\n \"\"\"\n self.bs = pd_scores.shape[0]\n self.n_max_boxes = gt_bboxes.shape[1]\n device = gt_bboxes.device\n\n if self.n_max_boxes == 0:\n return (\n torch.full_like(pd_scores[..., 0], self.num_classes),\n torch.zeros_like(pd_bboxes),\n torch.zeros_like(pd_scores),\n torch.zeros_like(pd_scores[..., 0]),\n torch.zeros_like(pd_scores[..., 0]),\n )\n\n try:\n return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)\n except torch.cuda.OutOfMemoryError:\n # Move tensors to CPU, compute, then move back to original device\n LOGGER.warning(\"CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU\")\n cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]\n result = self._forward(*cpu_tensors)\n return tuple(t.to(device) for t in result)\n\n def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):\n \"\"\"\n Compute the task-aligned assignment.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).\n\n Returns:\n target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).\n target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).\n target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).\n fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).\n target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).\n \"\"\"\n mask_pos, align_metric, overlaps = self.get_pos_mask(\n pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt\n )\n\n target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)\n\n # Assigned target\n target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)\n\n # Normalize\n align_metric *= 
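The forward() above guards the heavy assignment with a CUDA out-of-memory fallback: recompute on CPU, then move the results back to the original device. Below is a generic, hypothetical helper showing the same pattern; the name run_with_cpu_fallback and the assumption that fn returns a tuple of tensors are mine, not part of the library, and torch.cuda.OutOfMemoryError requires a reasonably recent PyTorch.

import torch

def run_with_cpu_fallback(fn, *tensors):
    # Hypothetical helper mirroring TaskAlignedAssigner.forward()'s OOM handling.
    device = tensors[0].device
    try:
        return fn(*tensors)
    except torch.cuda.OutOfMemoryError:
        cpu_tensors = [t.cpu() for t in tensors]
        result = fn(*cpu_tensors)            # fn is assumed to return a tuple of tensors
        return tuple(t.to(device) for t in result)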
mask_pos\n pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj\n pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj\n norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)\n target_scores = target_scores * norm_align_metric\n\n return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx\n\n def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):\n \"\"\"\n Get positive mask for each ground truth box.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).\n\n Returns:\n mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).\n align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).\n overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).\n \"\"\"\n mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)\n # Get anchor_align metric, (b, max_num_obj, h*w)\n align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)\n # Get topk_metric mask, (b, max_num_obj, h*w)\n mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())\n # Merge all mask to a final mask, (b, max_num_obj, h*w)\n mask_pos = mask_topk * mask_in_gts * mask_gt\n\n return mask_pos, align_metric, overlaps\n\n def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):\n \"\"\"\n Compute alignment metric given predicted and ground truth bounding boxes.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).\n\n Returns:\n align_metric (torch.Tensor): Alignment metric combining classification and localization.\n overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.\n \"\"\"\n na = pd_bboxes.shape[-2]\n mask_gt = mask_gt.bool() # b, max_num_obj, h*w\n overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)\n bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)\n\n ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj\n ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj\n ind[1] = gt_labels.squeeze(-1) # b, max_num_obj\n # Get the scores of each grid for each gt cls\n bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w\n\n # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)\n pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, 
-1)[mask_gt]\n gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]\n overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)\n\n align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)\n return align_metric, overlaps\n\n def iou_calculation(self, gt_bboxes, pd_bboxes):\n \"\"\"\n Calculate IoU for horizontal bounding boxes.\n\n Args:\n gt_bboxes (torch.Tensor): Ground truth boxes.\n pd_bboxes (torch.Tensor): Predicted boxes.\n\n Returns:\n (torch.Tensor): IoU values between each pair of boxes.\n \"\"\"\n return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)\n\n def select_topk_candidates(self, metrics, topk_mask=None):\n \"\"\"\n Select the top-k candidates based on the given metrics.\n\n Args:\n metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is\n the maximum number of objects, and h*w represents the total number of anchor points.\n topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where\n topk is the number of top candidates to consider. If not provided, the top-k values are automatically\n computed based on the given metrics.\n\n Returns:\n (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.\n \"\"\"\n # (b, max_num_obj, topk)\n topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)\n if topk_mask is None:\n topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)\n # (b, max_num_obj, topk)\n topk_idxs.masked_fill_(~topk_mask, 0)\n\n # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)\n count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)\n ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)\n for k in range(self.topk):\n # Expand topk_idxs for each value of k and add 1 at the specified positions\n count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)\n # Filter invalid bboxes\n count_tensor.masked_fill_(count_tensor > 1, 0)\n\n return count_tensor.to(metrics.dtype)\n\n def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):\n \"\"\"\n Compute target labels, target bounding boxes, and target scores for the positive anchor points.\n\n Args:\n gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the\n batch size and max_num_obj is the maximum number of objects.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).\n target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive\n anchor points, with shape (b, h*w), where h*w is the total\n number of anchor points.\n fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive\n (foreground) anchor points.\n\n Returns:\n target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).\n target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).\n target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).\n \"\"\"\n # Assigned target labels, (b, 1)\n batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]\n target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)\n target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)\n\n # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)\n target_bboxes = 
gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]\n\n # Assigned target scores\n target_labels.clamp_(0)\n\n # 10x faster than F.one_hot()\n target_scores = torch.zeros(\n (target_labels.shape[0], target_labels.shape[1], self.num_classes),\n dtype=torch.int64,\n device=target_labels.device,\n ) # (b, h*w, 80)\n target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)\n\n fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)\n target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)\n\n return target_labels, target_bboxes, target_scores\n\n @staticmethod\n def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):\n \"\"\"\n Select positive anchor centers within ground truth bounding boxes.\n\n Args:\n xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).\n eps (float, optional): Small value for numerical stability.\n\n Returns:\n (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).\n\n Note:\n b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.\n Bounding box format: [x_min, y_min, x_max, y_max].\n \"\"\"\n n_anchors = xy_centers.shape[0]\n bs, n_boxes, _ = gt_bboxes.shape\n lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom\n bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)\n return bbox_deltas.amin(3).gt_(eps)\n\n @staticmethod\n def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):\n \"\"\"\n Select anchor boxes with highest IoU when assigned to multiple ground truths.\n\n Args:\n mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).\n overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).\n n_max_boxes (int): Maximum number of ground truth boxes.\n\n Returns:\n target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).\n fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).\n mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).\n \"\"\"\n # Convert (b, n_max_boxes, h*w) -> (b, h*w)\n fg_mask = mask_pos.sum(-2)\n if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes\n mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)\n max_overlaps_idx = overlaps.argmax(1) # (b, h*w)\n\n is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)\n is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)\n\n mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)\n fg_mask = mask_pos.sum(-2)\n # Find each grid serve which gt(index)\n target_gt_idx = mask_pos.argmax(-2) # (b, h*w)\n return target_gt_idx, fg_mask, mask_pos", "chunk_type": "class", "name": "TaskAlignedAssigner", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 14, "end_line": 329, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "A task-aligned assigner for object detection.\n\nThis class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both\nclassification and localization information.\n\nAttributes:\n topk (int): The number of top candidates to consider.\n num_classes (int): The number of object classes.\n alpha (float): The alpha parameter for the classification component of the task-aligned metric.\n beta (float): The beta parameter for the localization component of 
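Editor's note: the get_targets code above builds one-hot target scores with an in-place scatter_ instead of F.one_hot for speed. A small standalone check (illustrative only, not part of the module) confirms the two produce identical results:

import torch
import torch.nn.functional as F

b, hw, nc = 2, 8400, 80
labels = torch.randint(0, nc, (b, hw))             # (b, h*w) assigned class indices
onehot = torch.zeros(b, hw, nc, dtype=torch.int64)
onehot.scatter_(2, labels.unsqueeze(-1), 1)        # same trick as in get_targets()
print(torch.equal(onehot, F.one_hot(labels, nc)))  # True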
the task-aligned metric.\n eps (float): A small value to prevent division by zero.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "torch.nn", "LOGGER", "checks.check_version", "metrics.bbox_iou", "metrics.probiou", "ops.xywhr2xyxyxyxy", "nn.Module" ], "chunk_id": "class_TaskAlignedAssigner_d717a5d2" }, { "content": "class RotatedTaskAlignedAssigner(TaskAlignedAssigner):\n \"\"\"Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.\"\"\"\n\n def iou_calculation(self, gt_bboxes, pd_bboxes):\n \"\"\"Calculate IoU for rotated bounding boxes.\"\"\"\n return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)\n\n @staticmethod\n def select_candidates_in_gts(xy_centers, gt_bboxes):\n \"\"\"\n Select the positive anchor center in gt for rotated bounding boxes.\n\n Args:\n xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).\n\n Returns:\n (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).\n \"\"\"\n # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)\n corners = xywhr2xyxyxyxy(gt_bboxes)\n # (b, n_boxes, 1, 2)\n a, b, _, d = corners.split(1, dim=-2)\n ab = b - a\n ad = d - a\n\n # (b, n_boxes, h*w, 2)\n ap = xy_centers - a\n norm_ab = (ab * ab).sum(dim=-1)\n norm_ad = (ad * ad).sum(dim=-1)\n ap_dot_ab = (ap * ab).sum(dim=-1)\n ap_dot_ad = (ap * ad).sum(dim=-1)\n return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box", "chunk_type": "class", "name": "RotatedTaskAlignedAssigner", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 332, "end_line": 364, "start_col": 0, "end_col": 100, "parent_name": null, "docstring": "Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "torch.nn", "LOGGER", "checks.check_version", "metrics.bbox_iou", "metrics.probiou", "ops.xywhr2xyxyxyxy", "TaskAlignedAssigner" ], "chunk_id": "class_RotatedTaskAlignedAssigner_483ffc95" }, { "content": "def make_anchors(feats, strides, grid_cell_offset=0.5):\n \"\"\"Generate anchors from features.\"\"\"\n anchor_points, stride_tensor = [], []\n assert feats is not None\n dtype, device = feats[0].dtype, feats[0].device\n for i, stride in enumerate(strides):\n h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))\n sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x\n sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y\n sy, sx = torch.meshgrid(sy, sx, indexing=\"ij\") if TORCH_1_10 else torch.meshgrid(sy, sx)\n anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))\n stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))\n return torch.cat(anchor_points), torch.cat(stride_tensor)", "chunk_type": "function", "name": "make_anchors", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 367, "end_line": 379, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": "Generate anchors from features.", "parameters": [ "feats", "strides", "grid_cell_offset" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "torch", "torch.nn", "LOGGER", "checks.check_version", "metrics.bbox_iou", "metrics.probiou", "ops.xywhr2xyxyxyxy" ], 
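Editor's note: a minimal usage sketch of the assigner described above, driven with random tensors shaped as in the forward() docstring. It assumes the constructor accepts topk, num_classes, alpha and beta keyword arguments matching the attributes listed in the class docstring; the concrete values here are arbitrary:

import torch
from ultralytics.utils.tal import TaskAlignedAssigner

bs, na, nc, nmax = 2, 8400, 80, 4
assigner = TaskAlignedAssigner(topk=10, num_classes=nc, alpha=0.5, beta=6.0)  # assumed kwargs

pd_scores = torch.rand(bs, na, nc)                                  # (bs, num_total_anchors, num_classes)
xy = torch.rand(bs, na, 2) * 600
pd_bboxes = torch.cat([xy, xy + torch.rand(bs, na, 2) * 40], -1)    # valid xyxy predictions
anc_points = torch.rand(na, 2) * 640                                # (num_total_anchors, 2)
gxy = torch.rand(bs, nmax, 2) * 600
gt_bboxes = torch.cat([gxy, gxy + torch.rand(bs, nmax, 2) * 40], -1)
gt_labels = torch.randint(0, nc, (bs, nmax, 1))
mask_gt = torch.ones(bs, nmax, 1)

labels, bboxes, scores, fg_mask, gt_idx = assigner(
    pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt
)
print(fg_mask.shape, int(fg_mask.sum()))  # (bs, num_total_anchors) and the number of foreground anchors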
"chunk_id": "function_make_anchors_d7e2eb74" }, { "content": "def dist2bbox(distance, anchor_points, xywh=True, dim=-1):\n \"\"\"Transform distance(ltrb) to box(xywh or xyxy).\"\"\"\n lt, rb = distance.chunk(2, dim)\n x1y1 = anchor_points - lt\n x2y2 = anchor_points + rb\n if xywh:\n c_xy = (x1y1 + x2y2) / 2\n wh = x2y2 - x1y1\n return torch.cat((c_xy, wh), dim) # xywh bbox\n return torch.cat((x1y1, x2y2), dim) # xyxy bbox", "chunk_type": "function", "name": "dist2bbox", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 382, "end_line": 391, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "Transform distance(ltrb) to box(xywh or xyxy).", "parameters": [ "distance", "anchor_points", "xywh", "dim" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "torch", "torch.nn", "LOGGER", "checks.check_version", "metrics.bbox_iou", "metrics.probiou", "ops.xywhr2xyxyxyxy" ], "chunk_id": "function_dist2bbox_efa5c776" }, { "content": "def bbox2dist(anchor_points, bbox, reg_max):\n \"\"\"Transform bbox(xyxy) to dist(ltrb).\"\"\"\n x1y1, x2y2 = bbox.chunk(2, -1)\n return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)", "chunk_type": "function", "name": "bbox2dist", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 394, "end_line": 397, "start_col": 0, "end_col": 96, "parent_name": null, "docstring": "Transform bbox(xyxy) to dist(ltrb).", "parameters": [ "anchor_points", "bbox", "reg_max" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "torch", "torch.nn", "LOGGER", "checks.check_version", "metrics.bbox_iou", "metrics.probiou", "ops.xywhr2xyxyxyxy" ], "chunk_id": "function_bbox2dist_7eed802e" }, { "content": "def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):\n \"\"\"\n Decode predicted rotated bounding box coordinates from anchor points and distribution.\n\n Args:\n pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).\n pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).\n anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).\n dim (int, optional): Dimension along which to split.\n\n Returns:\n (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).\n \"\"\"\n lt, rb = pred_dist.split(2, dim=dim)\n cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)\n # (bs, h*w, 1)\n xf, yf = ((rb - lt) / 2).split(1, dim=dim)\n x, y = xf * cos - yf * sin, xf * sin + yf * cos\n xy = torch.cat([x, y], dim=dim) + anchor_points\n return torch.cat([xy, lt + rb], dim=dim)", "chunk_type": "function", "name": "dist2rbox", "file_path": "ultralytics\\ultralytics\\utils\\tal.py", "start_line": 400, "end_line": 419, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Decode predicted rotated bounding box coordinates from anchor points and distribution.\n\nArgs:\n pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).\n pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).\n anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).\n dim (int, optional): Dimension along which to split.\n\nReturns:\n (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).", "parameters": [ "pred_dist", "pred_angle", "anchor_points", "dim" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "torch", "torch.nn", "LOGGER", "checks.check_version", "metrics.bbox_iou", "metrics.probiou", 
"ops.xywhr2xyxyxyxy" ], "chunk_id": "function_dist2rbox_40bbab6d" }, { "content": "import functools", "chunk_type": "import", "name": "functools", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_functools_1761ad8e" }, { "content": "import gc", "chunk_type": "import", "name": "gc", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_gc_79b36253" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_94b5469f" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_196299bc" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_11b50b2e" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_ba590618" }, { "content": "from contextlib import contextmanager", "chunk_type": "import", "name": "contextmanager", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextmanager_db6adbfc" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_87deb03a" }, { "content": "from datetime import datetime", "chunk_type": "import", "name": "datetime", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
"dependencies": null, "chunk_id": "import_datetime_7e96701c" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_229a89f2" }, { "content": "from typing import Any, Dict, Union", "chunk_type": "import", "name": "Any, Dict, Union", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Union_eec6484e" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_38987177" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_b805e815" }, { "content": "import torch.distributed as dist", "chunk_type": "import", "name": "torch.distributed", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.distributed_ce6f6f07" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_25306e3a" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_4d818dae" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___aca33539" }, { "content": "from ultralytics.utils import (\n DEFAULT_CFG_DICT,\n DEFAULT_CFG_KEYS,\n LOGGER,\n NUM_THREADS,\n PYTHON_VERSION,\n TORCHVISION_VERSION,\n WINDOWS,\n colorstr,\n)", 
"chunk_type": "import", "name": "DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, NUM_THREADS, PYTHON_VERSION, TORCHVISION_VERSION, WINDOWS, colorstr", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 22, "end_line": 31, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, NUM_THREADS, PYTHON_VERSION, TORCHVISION_VERSION, WINDOWS, colorstr_60e35477" }, { "content": "from ultralytics.utils.checks import check_version", "chunk_type": "import", "name": "check_version", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 32, "end_line": 32, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_version_3bb078e0" }, { "content": "from ultralytics.utils.patches import torch_load", "chunk_type": "import", "name": "torch_load", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 33, "end_line": 33, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_load_bff04f79" }, { "content": "TORCH_1_9 = check_version(torch.__version__, \"1.9.0\")", "chunk_type": "variable", "name": "TORCH_1_9", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCH_1_9_deaf6ac0" }, { "content": "TORCH_1_13 = check_version(torch.__version__, \"1.13.0\")", "chunk_type": "variable", "name": "TORCH_1_13", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 37, "end_line": 37, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCH_1_13_0b85fd27" }, { "content": "TORCH_2_0 = check_version(torch.__version__, \"2.0.0\")", "chunk_type": "variable", "name": "TORCH_2_0", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 38, "end_line": 38, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCH_2_0_d41a9c8d" }, { "content": "TORCH_2_4 = check_version(torch.__version__, \"2.4.0\")", "chunk_type": "variable", "name": "TORCH_2_4", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCH_2_4_259b5df7" }, { "content": "TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, \"0.10.0\")", "chunk_type": "variable", "name": "TORCHVISION_0_10", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 40, "end_line": 40, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": 
null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCHVISION_0_10_63667314" }, { "content": "TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, \"0.11.0\")", "chunk_type": "variable", "name": "TORCHVISION_0_11", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 41, "end_line": 41, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCHVISION_0_11_c6dd5441" }, { "content": "TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, \"0.13.0\")", "chunk_type": "variable", "name": "TORCHVISION_0_13", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 42, "end_line": 42, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCHVISION_0_13_9aa804a6" }, { "content": "TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, \"0.18.0\")", "chunk_type": "variable", "name": "TORCHVISION_0_18", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 43, "end_line": 43, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCHVISION_0_18_240f011e" }, { "content": "def torch_distributed_zero_first(local_rank: int):\n \"\"\"Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first.\"\"\"\n initialized = dist.is_available() and dist.is_initialized()\n use_ids = initialized and dist.get_backend() == \"nccl\"\n\n if initialized and local_rank not in {-1, 0}:\n dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()\n yield\n if initialized and local_rank == 0:\n dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()", "chunk_type": "function", "name": "torch_distributed_zero_first", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 52, "end_line": 61, "start_col": 0, "end_col": 76, "parent_name": null, "docstring": "Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first.", "parameters": [ "local_rank: int" ], "return_type": null, "decorators": [ "contextmanager" ], "complexity_score": 3, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_torch_distributed_zero_first_f35ec42b" }, { "content": "def smart_inference_mode():\n \"\"\"Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.\"\"\"\n\n def 
decorate(fn):\n \"\"\"Apply appropriate torch decorator for inference mode based on torch version.\"\"\"\n if TORCH_1_9 and torch.is_inference_mode_enabled():\n return fn # already in inference_mode, act as a pass-through\n else:\n return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)\n\n return decorate", "chunk_type": "function", "name": "smart_inference_mode", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 64, "end_line": 74, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_smart_inference_mode_a69c8fae" }, { "content": "def autocast(enabled: bool, device: str = \"cuda\"):\n \"\"\"\n Get the appropriate autocast context manager based on PyTorch version and AMP setting.\n\n This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both\n older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.\n\n Args:\n enabled (bool): Whether to enable automatic mixed precision.\n device (str, optional): The device to use for autocast.\n\n Returns:\n (torch.amp.autocast): The appropriate autocast context manager.\n\n Notes:\n - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.\n - For older versions, it uses `torch.cuda.autocast`.\n\n Examples:\n >>> with autocast(enabled=True):\n ... # Your mixed precision operations here\n ... pass\n \"\"\"\n if TORCH_1_13:\n return torch.amp.autocast(device, enabled=enabled)\n else:\n return torch.cuda.amp.autocast(enabled)", "chunk_type": "function", "name": "autocast", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 77, "end_line": 103, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Get the appropriate autocast context manager based on PyTorch version and AMP setting.\n\nThis function returns a context manager for automatic mixed precision (AMP) training that is compatible with both\nolder and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.\n\nArgs:\n enabled (bool): Whether to enable automatic mixed precision.\n device (str, optional): The device to use for autocast.\n\nReturns:\n (torch.amp.autocast): The appropriate autocast context manager.\n\nNotes:\n - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.\n - For older versions, it uses `torch.cuda.autocast`.\n\nExamples:\n >>> with autocast(enabled=True):\n ... # Your mixed precision operations here\n ... 
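Editor's note: smart_inference_mode above is a decorator factory; a minimal sketch of how it is applied:

import torch
from ultralytics.utils.torch_utils import smart_inference_mode

@smart_inference_mode()
def count_positive(model, x):
    # Runs under torch.inference_mode() on torch>=1.9, torch.no_grad() otherwise
    return int((model(x) > 0).sum())

model = torch.nn.Linear(4, 2)
print(count_positive(model, torch.randn(3, 4)))  # gradient tracking is disabled inside the call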
pass", "parameters": [ "enabled: bool", "device: str" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_autocast_34300a42" }, { "content": "def get_cpu_info():\n \"\"\"Return a string with system CPU information, i.e. 'Apple M2'.\"\"\"\n from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error\n\n if \"cpu_info\" not in PERSISTENT_CACHE:\n try:\n import cpuinfo # pip install py-cpuinfo\n\n k = \"brand_raw\", \"hardware_raw\", \"arch_string_raw\" # keys sorted by preference\n info = cpuinfo.get_cpu_info() # info dict\n string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], \"unknown\")\n PERSISTENT_CACHE[\"cpu_info\"] = string.replace(\"(R)\", \"\").replace(\"CPU \", \"\").replace(\"@ \", \"\")\n except Exception:\n pass\n return PERSISTENT_CACHE.get(\"cpu_info\", \"unknown\")", "chunk_type": "function", "name": "get_cpu_info", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 107, "end_line": 121, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": "Return a string with system CPU information, i.e. 'Apple M2'.", "parameters": [], "return_type": null, "decorators": [ "functools.lru_cache" ], "complexity_score": 3, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_cpu_info_d2bd4308" }, { "content": "def get_gpu_info(index):\n \"\"\"Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'.\"\"\"\n properties = torch.cuda.get_device_properties(index)\n return f\"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB\"", "chunk_type": "function", "name": "get_gpu_info", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 125, "end_line": 128, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": "Return a string with system GPU information, i.e. 
'Tesla T4, 15102MiB'.", "parameters": [ "index" ], "return_type": null, "decorators": [ "functools.lru_cache" ], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_gpu_info_3a395543" }, { "content": "def select_device(device=\"\", batch=0, newline=False, verbose=True):\n \"\"\"\n Select the appropriate PyTorch device based on the provided arguments.\n\n The function takes a string specifying the device or a torch.device object and returns a torch.device object\n representing the selected device. The function also validates the number of available devices and raises an\n exception if the requested device(s) are not available.\n\n Args:\n device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or\n 'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.\n batch (int, optional): Batch size being used in your model.\n newline (bool, optional): If True, adds a newline at the end of the log string.\n verbose (bool, optional): If True, logs the device information.\n\n Returns:\n (torch.device): Selected device.\n\n Raises:\n ValueError: If the specified device is not available or if the batch size is not a multiple of the number of\n devices when using multiple GPUs.\n\n Examples:\n >>> select_device(\"cuda:0\")\n device(type='cuda', index=0)\n\n >>> select_device(\"cpu\")\n device(type='cpu')\n\n Notes:\n Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.\n \"\"\"\n if isinstance(device, torch.device) or str(device).startswith((\"tpu\", \"intel\")):\n return device\n\n s = f\"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} \"\n device = str(device).lower()\n for remove in \"cuda:\", \"none\", \"(\", \")\", \"[\", \"]\", \"'\", \" \":\n device = device.replace(remove, \"\") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'\n\n # Auto-select GPUs\n if \"-1\" in device:\n from ultralytics.utils.autodevice import GPUInfo\n\n # Replace each -1 with a selected GPU or remove it\n parts = device.split(\",\")\n selected = GPUInfo().select_idle_gpu(count=parts.count(\"-1\"), min_memory_fraction=0.2)\n for i in range(len(parts)):\n if parts[i] == \"-1\":\n parts[i] = str(selected.pop(0)) if selected else \"\"\n device = \",\".join(p for p in parts if p)\n\n cpu = device == \"cpu\"\n mps = device in {\"mps\", \"mps:0\"} # Apple Metal Performance Shaders (MPS)\n if cpu or mps:\n os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\" # force torch.cuda.is_available() = False\n elif device: # non-cpu device requested\n if device == \"cuda\":\n device = \"0\"\n if \",\" in device:\n device = \",\".join([x for x in device.split(\",\") if x]) # remove sequential commas, i.e. 
\"0,,1\" -> \"0,1\"\n visible = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n os.environ[\"CUDA_VISIBLE_DEVICES\"] = device # set environment variable - must be before assert is_available()\n if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(\",\"))):\n LOGGER.info(s)\n install = (\n \"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no \"\n \"CUDA devices are seen by torch.\\n\"\n if torch.cuda.device_count() == 0\n else \"\"\n )\n raise ValueError(\n f\"Invalid CUDA 'device={device}' requested.\"\n f\" Use 'device=cpu' or pass valid CUDA device(s) if available,\"\n f\" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\\n\"\n f\"\\ntorch.cuda.is_available(): {torch.cuda.is_available()}\"\n f\"\\ntorch.cuda.device_count(): {torch.cuda.device_count()}\"\n f\"\\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\\n\"\n f\"{install}\"\n )\n\n if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available\n devices = device.split(\",\") if device else \"0\" # i.e. \"0,1\" -> [\"0\", \"1\"]\n n = len(devices) # device count\n if n > 1: # multi-GPU\n if batch < 1:\n raise ValueError(\n \"AutoBatch with batch<1 not supported for Multi-GPU training, \"\n f\"please specify a valid batch size multiple of GPU count {n}, i.e. batch={n * 8}.\"\n )\n if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count\n raise ValueError(\n f\"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or \"\n f\"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.\"\n )\n space = \" \" * (len(s) + 1)\n for i, d in enumerate(devices):\n s += f\"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\\n\" # bytes to MB\n arg = \"cuda:0\"\n elif mps and TORCH_2_0 and torch.backends.mps.is_available():\n # Prefer MPS if available\n s += f\"MPS ({get_cpu_info()})\\n\"\n arg = \"mps\"\n else: # revert to CPU\n s += f\"CPU ({get_cpu_info()})\\n\"\n arg = \"cpu\"\n\n if arg in {\"cpu\", \"mps\"}:\n torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training\n if verbose:\n LOGGER.info(s if newline else s.rstrip())\n return torch.device(arg)", "chunk_type": "function", "name": "select_device", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 131, "end_line": 242, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Select the appropriate PyTorch device based on the provided arguments.\n\nThe function takes a string specifying the device or a torch.device object and returns a torch.device object\nrepresenting the selected device. The function also validates the number of available devices and raises an\nexception if the requested device(s) are not available.\n\nArgs:\n device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or\n 'cuda', or '0' or '0,1,2,3'. 
Auto-selects the first available GPU, or CPU if no GPU is available.\n batch (int, optional): Batch size being used in your model.\n newline (bool, optional): If True, adds a newline at the end of the log string.\n verbose (bool, optional): If True, logs the device information.\n\nReturns:\n (torch.device): Selected device.\n\nRaises:\n ValueError: If the specified device is not available or if the batch size is not a multiple of the number of\n devices when using multiple GPUs.\n\nExamples:\n >>> select_device(\"cuda:0\")\n device(type='cuda', index=0)\n\n >>> select_device(\"cpu\")\n device(type='cpu')\n\nNotes:\n Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.", "parameters": [ "device", "batch", "newline", "verbose" ], "return_type": null, "decorators": [], "complexity_score": 21, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_select_device_6a27eb04" }, { "content": "def time_sync():\n \"\"\"Return PyTorch-accurate time.\"\"\"\n if torch.cuda.is_available():\n torch.cuda.synchronize()\n return time.time()", "chunk_type": "function", "name": "time_sync", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 245, "end_line": 249, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Return PyTorch-accurate time.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_time_sync_6134daef" }, { "content": "def fuse_conv_and_bn(conv, bn):\n \"\"\"Fuse Conv2d() and BatchNorm2d() layers.\"\"\"\n fusedconv = (\n nn.Conv2d(\n conv.in_channels,\n conv.out_channels,\n kernel_size=conv.kernel_size,\n stride=conv.stride,\n padding=conv.padding,\n dilation=conv.dilation,\n groups=conv.groups,\n bias=True,\n )\n .requires_grad_(False)\n .to(conv.weight.device)\n )\n\n # Prepare filters\n w_conv = conv.weight.view(conv.out_channels, -1)\n w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))\n 
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))\n\n # Prepare spatial bias\n b_conv = (\n torch.zeros(conv.weight.shape[0], dtype=conv.weight.dtype, device=conv.weight.device)\n if conv.bias is None\n else conv.bias\n )\n b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))\n fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)\n\n return fusedconv", "chunk_type": "function", "name": "fuse_conv_and_bn", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 252, "end_line": 283, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Fuse Conv2d() and BatchNorm2d() layers.", "parameters": [ "conv", "bn" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_fuse_conv_and_bn_1390fc68" }, { "content": "def fuse_deconv_and_bn(deconv, bn):\n \"\"\"Fuse ConvTranspose2d() and BatchNorm2d() layers.\"\"\"\n fuseddconv = (\n nn.ConvTranspose2d(\n deconv.in_channels,\n deconv.out_channels,\n kernel_size=deconv.kernel_size,\n stride=deconv.stride,\n padding=deconv.padding,\n output_padding=deconv.output_padding,\n dilation=deconv.dilation,\n groups=deconv.groups,\n bias=True,\n )\n .requires_grad_(False)\n .to(deconv.weight.device)\n )\n\n # Prepare filters\n w_deconv = deconv.weight.view(deconv.out_channels, -1)\n w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))\n fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))\n\n # Prepare spatial bias\n b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias\n b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))\n fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)\n\n return fuseddconv", "chunk_type": "function", "name": "fuse_deconv_and_bn", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 286, "end_line": 314, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Fuse ConvTranspose2d() and BatchNorm2d() layers.", "parameters": [ "deconv", "bn" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", 
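Editor's note: a quick numerical check of fuse_conv_and_bn above (illustrative only): in eval mode the fused convolution should reproduce Conv2d followed by BatchNorm2d up to floating-point tolerance. Randomizing the BN parameters and running statistics makes the comparison non-trivial:

import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(16).eval()  # eval mode: uses running statistics
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    y_ref = bn(conv(x))
    y_fused = fuse_conv_and_bn(conv, bn)(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True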
"ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_fuse_deconv_and_bn_2d6caedb" }, { "content": "def model_info(model, detailed=False, verbose=True, imgsz=640):\n \"\"\"\n Print and return detailed model information layer by layer.\n\n Args:\n model (nn.Module): Model to analyze.\n detailed (bool, optional): Whether to print detailed layer information.\n verbose (bool, optional): Whether to print model information.\n imgsz (int | list, optional): Input image size.\n\n Returns:\n n_l (int): Number of layers.\n n_p (int): Number of parameters.\n n_g (int): Number of gradients.\n flops (float): GFLOPs.\n \"\"\"\n if not verbose:\n return\n n_p = get_num_params(model) # number of parameters\n n_g = get_num_gradients(model) # number of gradients\n layers = __import__(\"collections\").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)\n n_l = len(layers) # number of layers\n if detailed:\n h = f\"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}\"\n LOGGER.info(h)\n for i, (mn, m) in enumerate(layers.items()):\n mn = mn.replace(\"module_list.\", \"\")\n mt = m.__class__.__name__\n if len(m._parameters):\n for pn, p in m.named_parameters():\n LOGGER.info(\n f\"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}\"\n )\n else: # layers with no learnable params\n LOGGER.info(f\"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}\")\n\n flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. 
imgsz=640 or imgsz=[640, 320]\n fused = \" (fused)\" if getattr(model, \"is_fused\", lambda: False)() else \"\"\n fs = f\", {flops:.1f} GFLOPs\" if flops else \"\"\n yaml_file = getattr(model, \"yaml_file\", \"\") or getattr(model, \"yaml\", {}).get(\"yaml_file\", \"\")\n model_name = Path(yaml_file).stem.replace(\"yolo\", \"YOLO\") or \"Model\"\n LOGGER.info(f\"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}\")\n return n_l, n_p, n_g, flops", "chunk_type": "function", "name": "model_info", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 317, "end_line": 359, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "Print and return detailed model information layer by layer.\n\nArgs:\n model (nn.Module): Model to analyze.\n detailed (bool, optional): Whether to print detailed layer information.\n verbose (bool, optional): Whether to print model information.\n imgsz (int | list, optional): Input image size.\n\nReturns:\n n_l (int): Number of layers.\n n_p (int): Number of parameters.\n n_g (int): Number of gradients.\n flops (float): GFLOPs.", "parameters": [ "model", "detailed", "verbose", "imgsz" ], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_model_info_88b9a323" }, { "content": "def get_num_params(model):\n \"\"\"Return the total number of parameters in a YOLO model.\"\"\"\n return sum(x.numel() for x in model.parameters())", "chunk_type": "function", "name": "get_num_params", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 362, "end_line": 364, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": "Return the total number of parameters in a YOLO model.", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_num_params_69f2ae44" }, { "content": "def get_num_gradients(model):\n \"\"\"Return 
the total number of parameters with gradients in a YOLO model.\"\"\"\n return sum(x.numel() for x in model.parameters() if x.requires_grad)", "chunk_type": "function", "name": "get_num_gradients", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 367, "end_line": 369, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": "Return the total number of parameters with gradients in a YOLO model.", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_num_gradients_b2843e9c" }, { "content": "def model_info_for_loggers(trainer):\n \"\"\"\n Return model info dict with useful model information.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.\n\n Returns:\n (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.\n\n Examples:\n YOLOv8n info for loggers\n >>> results = {\n ... \"model/parameters\": 3151904,\n ... \"model/GFLOPs\": 8.746,\n ... \"model/speed_ONNX(ms)\": 41.244,\n ... \"model/speed_TensorRT(ms)\": 3.211,\n ... \"model/speed_PyTorch(ms)\": 18.755,\n ...}\n \"\"\"\n if trainer.args.profile: # profile ONNX and TensorRT times\n from ultralytics.utils.benchmarks import ProfileModels\n\n results = ProfileModels([trainer.last], device=trainer.device).run()[0]\n results.pop(\"model/name\")\n else: # only return PyTorch times from most recent validation\n results = {\n \"model/parameters\": get_num_params(trainer.model),\n \"model/GFLOPs\": round(get_flops(trainer.model), 3),\n }\n results[\"model/speed_PyTorch(ms)\"] = round(trainer.validator.speed[\"inference\"], 3)\n return results", "chunk_type": "function", "name": "model_info_for_loggers", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 372, "end_line": 403, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Return model info dict with useful model information.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.\n\nReturns:\n (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.\n\nExamples:\n YOLOv8n info for loggers\n >>> results = {\n ... \"model/parameters\": 3151904,\n ... \"model/GFLOPs\": 8.746,\n ... \"model/speed_ONNX(ms)\": 41.244,\n ... \"model/speed_TensorRT(ms)\": 3.211,\n ... 
\"model/speed_PyTorch(ms)\": 18.755,\n ...}", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_model_info_for_loggers_613e00b0" }, { "content": "def get_flops(model, imgsz=640):\n \"\"\"\n Calculate FLOPs (floating point operations) for a model in billions.\n\n Attempts two calculation methods: first with a stride-based tensor for efficiency,\n then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0\n if thop library is unavailable or calculation fails.\n\n Args:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\n Returns:\n (float): The model FLOPs in billions.\n \"\"\"\n try:\n import thop\n except ImportError:\n thop = None # conda support without 'ultralytics-thop' installed\n\n if not thop:\n return 0.0 # if not installed return 0.0 GFLOPs\n\n try:\n model = de_parallel(model)\n p = next(model.parameters())\n if not isinstance(imgsz, list):\n imgsz = [imgsz, imgsz] # expand if int/float\n try:\n # Method 1: Use stride-based input tensor\n stride = max(int(model.stride.max()), 32) if hasattr(model, \"stride\") else 32 # max stride\n im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format\n flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs\n return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs\n except Exception:\n # Method 2: Use actual image size (required for RTDETR models)\n im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format\n return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs\n except Exception:\n return 0.0", "chunk_type": "function", "name": "get_flops", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 406, "end_line": 445, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Calculate FLOPs (floating point operations) for a model in billions.\n\nAttempts two calculation methods: first with a stride-based tensor for efficiency,\nthen falls back to full image size if needed (e.g., for RTDETR models). 
Returns 0.0\nif thop library is unavailable or calculation fails.\n\nArgs:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\nReturns:\n (float): The model FLOPs in billions.", "parameters": [ "model", "imgsz" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_flops_17bd5430" }, { "content": "def get_flops_with_torch_profiler(model, imgsz=640):\n \"\"\"\n Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).\n\n Args:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\n Returns:\n (float): The model's FLOPs in billions.\n \"\"\"\n if not TORCH_2_0: # torch profiler implemented in torch>=2.0\n return 0.0\n model = de_parallel(model)\n p = next(model.parameters())\n if not isinstance(imgsz, list):\n imgsz = [imgsz, imgsz] # expand if int/float\n try:\n # Use stride size for input tensor\n stride = (max(int(model.stride.max()), 32) if hasattr(model, \"stride\") else 32) * 2 # max stride\n im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format\n with torch.profiler.profile(with_flops=True) as prof:\n model(im)\n flops = sum(x.flops for x in prof.key_averages()) / 1e9\n flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs\n except Exception:\n # Use actual image size for input tensor (i.e. 
required for RTDETR models)\n im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format\n with torch.profiler.profile(with_flops=True) as prof:\n model(im)\n flops = sum(x.flops for x in prof.key_averages()) / 1e9\n return flops", "chunk_type": "function", "name": "get_flops_with_torch_profiler", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 448, "end_line": 479, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).\n\nArgs:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\nReturns:\n (float): The model's FLOPs in billions.", "parameters": [ "model", "imgsz" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_flops_with_torch_profiler_1c92f7a3" }, { "content": "def initialize_weights(model):\n \"\"\"Initialize model weights to random values.\"\"\"\n for m in model.modules():\n t = type(m)\n if t is nn.Conv2d:\n pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n elif t is nn.BatchNorm2d:\n m.eps = 1e-3\n m.momentum = 0.03\n elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:\n m.inplace = True", "chunk_type": "function", "name": "initialize_weights", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 482, "end_line": 492, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Initialize model weights to random values.", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_initialize_weights_c22e7ae4" }, { "content": "def scale_img(img, ratio=1.0, same_shape=False, gs=32):\n \"\"\"\n Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.\n\n Args:\n img (torch.Tensor): Input 
image tensor.\n ratio (float, optional): Scaling ratio.\n same_shape (bool, optional): Whether to maintain the same shape.\n gs (int, optional): Grid size for padding.\n\n Returns:\n (torch.Tensor): Scaled and padded image tensor.\n \"\"\"\n if ratio == 1.0:\n return img\n h, w = img.shape[2:]\n s = (int(h * ratio), int(w * ratio)) # new size\n img = F.interpolate(img, size=s, mode=\"bilinear\", align_corners=False) # resize\n if not same_shape: # pad/crop img\n h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))\n return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean", "chunk_type": "function", "name": "scale_img", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 495, "end_line": 515, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.\n\nArgs:\n img (torch.Tensor): Input image tensor.\n ratio (float, optional): Scaling ratio.\n same_shape (bool, optional): Whether to maintain the same shape.\n gs (int, optional): Grid size for padding.\n\nReturns:\n (torch.Tensor): Scaled and padded image tensor.", "parameters": [ "img", "ratio", "same_shape", "gs" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_scale_img_f870b9db" }, { "content": "def copy_attr(a, b, include=(), exclude=()):\n \"\"\"\n Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.\n\n Args:\n a (Any): Destination object to copy attributes to.\n b (Any): Source object to copy attributes from.\n include (tuple, optional): Attributes to include. If empty, all attributes are included.\n exclude (tuple, optional): Attributes to exclude.\n \"\"\"\n for k, v in b.__dict__.items():\n if (len(include) and k not in include) or k.startswith(\"_\") or k in exclude:\n continue\n else:\n setattr(a, k, v)", "chunk_type": "function", "name": "copy_attr", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 518, "end_line": 532, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.\n\nArgs:\n a (Any): Destination object to copy attributes to.\n b (Any): Source object to copy attributes from.\n include (tuple, optional): Attributes to include. 
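A quick check of `scale_img` behaviour, padding the resized tensor up to the next grid-size multiple (the output shape below is worked out from the code, not measured):

```python
import torch
from ultralytics.utils.torch_utils import scale_img

im = torch.zeros(1, 3, 640, 480)        # BCHW input
out = scale_img(im, ratio=0.5, gs=32)   # resize to 320x240, then pad width to 256 (next /32 multiple)
print(out.shape)                        # torch.Size([1, 3, 320, 256])
```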
If empty, all attributes are included.\n exclude (tuple, optional): Attributes to exclude.", "parameters": [ "a", "b", "include", "exclude" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_copy_attr_7422bd0c" }, { "content": "def get_latest_opset():\n \"\"\"\n Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.\n\n Returns:\n (int): The ONNX opset version.\n \"\"\"\n if TORCH_1_13:\n # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'\n return max(int(k[14:]) for k in vars(torch.onnx) if \"symbolic_opset\" in k) - 1\n # Otherwise for PyTorch<=1.12 return the corresponding predefined opset\n version = torch.onnx.producer_version.rsplit(\".\", 1)[0] # i.e. '2.3'\n return {\"1.12\": 15, \"1.11\": 14, \"1.10\": 13, \"1.9\": 12, \"1.8\": 12}.get(version, 12)", "chunk_type": "function", "name": "get_latest_opset", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 535, "end_line": 547, "start_col": 0, "end_col": 86, "parent_name": null, "docstring": "Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.\n\nReturns:\n (int): The ONNX opset version.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_get_latest_opset_6f3df7f2" }, { "content": "def intersect_dicts(da, db, exclude=()):\n \"\"\"\n Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.\n\n Args:\n da (dict): First dictionary.\n db (dict): Second dictionary.\n exclude (tuple, optional): Keys to exclude.\n\n Returns:\n (dict): Dictionary of intersecting keys with matching shapes.\n \"\"\"\n return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}", "chunk_type": "function", "name": 
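`intersect_dicts` is typically used when transferring pretrained weights: only keys present in both state dicts with identical shapes survive. A small self-contained sketch:

```python
import torch.nn as nn
from ultralytics.utils.torch_utils import intersect_dicts

src = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16)).state_dict()
dst_model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
csd = intersect_dicts(src, dst_model.state_dict(), exclude=("num_batches_tracked",))
dst_model.load_state_dict(csd, strict=False)   # typical transfer-learning pattern
print(len(csd))                                # 6 matching tensors (conv + BN weights/biases/stats)
```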
"intersect_dicts", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 550, "end_line": 562, "start_col": 0, "end_col": 115, "parent_name": null, "docstring": "Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.\n\nArgs:\n da (dict): First dictionary.\n db (dict): Second dictionary.\n exclude (tuple, optional): Keys to exclude.\n\nReturns:\n (dict): Dictionary of intersecting keys with matching shapes.", "parameters": [ "da", "db", "exclude" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_intersect_dicts_de5ef473" }, { "content": "def is_parallel(model):\n \"\"\"\n Return True if model is of type DP or DDP.\n\n Args:\n model (nn.Module): Model to check.\n\n Returns:\n (bool): True if model is DataParallel or DistributedDataParallel.\n \"\"\"\n return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))", "chunk_type": "function", "name": "is_parallel", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 565, "end_line": 575, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": "Return True if model is of type DP or DDP.\n\nArgs:\n model (nn.Module): Model to check.\n\nReturns:\n (bool): True if model is DataParallel or DistributedDataParallel.", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_is_parallel_a6f44a0b" }, { "content": "def de_parallel(model):\n \"\"\"\n De-parallelize a model: return single-GPU model if model is of type DP or DDP.\n\n Args:\n model (nn.Module): Model to de-parallelize.\n\n Returns:\n (nn.Module): De-parallelized model.\n \"\"\"\n return model.module if is_parallel(model) else model", "chunk_type": "function", "name": "de_parallel", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 578, "end_line": 588, 
"start_col": 0, "end_col": 56, "parent_name": null, "docstring": "De-parallelize a model: return single-GPU model if model is of type DP or DDP.\n\nArgs:\n model (nn.Module): Model to de-parallelize.\n\nReturns:\n (nn.Module): De-parallelized model.", "parameters": [ "model" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_de_parallel_0b343ee4" }, { "content": "def one_cycle(y1=0.0, y2=1.0, steps=100):\n \"\"\"\n Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.\n\n Args:\n y1 (float, optional): Initial value.\n y2 (float, optional): Final value.\n steps (int, optional): Number of steps.\n\n Returns:\n (function): Lambda function for computing the sinusoidal ramp.\n \"\"\"\n return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1", "chunk_type": "function", "name": "one_cycle", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 591, "end_line": 603, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": "Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.\n\nArgs:\n y1 (float, optional): Initial value.\n y2 (float, optional): Final value.\n steps (int, optional): Number of steps.\n\nReturns:\n (function): Lambda function for computing the sinusoidal ramp.", "parameters": [ "y1", "y2", "steps" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_one_cycle_69bc49bf" }, { "content": "def init_seeds(seed=0, deterministic=False):\n \"\"\"\n Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.\n\n Args:\n seed (int, optional): Random seed.\n deterministic (bool, optional): Whether to set deterministic algorithms.\n \"\"\"\n random.seed(seed)\n np.random.seed(seed)\n torch.manual_seed(seed)\n torch.cuda.manual_seed(seed)\n 
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe\n # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287\n if deterministic:\n if TORCH_2_0:\n torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible\n torch.backends.cudnn.deterministic = True\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n os.environ[\"PYTHONHASHSEED\"] = str(seed)\n else:\n LOGGER.warning(\"Upgrade to torch>=2.0.0 for deterministic training.\")\n else:\n unset_deterministic()", "chunk_type": "function", "name": "init_seeds", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 606, "end_line": 629, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": "Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.\n\nArgs:\n seed (int, optional): Random seed.\n deterministic (bool, optional): Whether to set deterministic algorithms.", "parameters": [ "seed", "deterministic" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_init_seeds_12265020" }, { "content": "def unset_deterministic():\n \"\"\"Unset all the configurations applied for deterministic training.\"\"\"\n torch.use_deterministic_algorithms(False)\n torch.backends.cudnn.deterministic = False\n os.environ.pop(\"CUBLAS_WORKSPACE_CONFIG\", None)\n os.environ.pop(\"PYTHONHASHSEED\", None)", "chunk_type": "function", "name": "unset_deterministic", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 632, "end_line": 637, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Unset all the configurations applied for deterministic training.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_unset_deterministic_e2ce1747" }, { "content": "class ModelEMA:\n \"\"\"\n Updated 
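A reproducibility smoke test for `init_seeds` (CPU-only, no determinism flags needed):

```python
import torch
from ultralytics.utils.torch_utils import init_seeds

init_seeds(seed=0)
first = torch.rand(3)
init_seeds(seed=0)                        # re-seeding replays the same RNG stream
assert torch.equal(torch.rand(3), first)
```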
Exponential Moving Average (EMA) implementation.\n\n Keeps a moving average of everything in the model state_dict (parameters and buffers).\n For EMA details see References.\n\n To disable EMA set the `enabled` attribute to `False`.\n\n Attributes:\n ema (nn.Module): Copy of the model in evaluation mode.\n updates (int): Number of EMA updates.\n decay (function): Decay function that determines the EMA weight.\n enabled (bool): Whether EMA is enabled.\n\n References:\n - https://github.com/rwightman/pytorch-image-models\n - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage\n \"\"\"\n\n def __init__(self, model, decay=0.9999, tau=2000, updates=0):\n \"\"\"\n Initialize EMA for 'model' with given arguments.\n\n Args:\n model (nn.Module): Model to create EMA for.\n decay (float, optional): Maximum EMA decay rate.\n tau (int, optional): EMA decay time constant.\n updates (int, optional): Initial number of updates.\n \"\"\"\n self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA\n self.updates = updates # number of EMA updates\n self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)\n for p in self.ema.parameters():\n p.requires_grad_(False)\n self.enabled = True\n\n def update(self, model):\n \"\"\"\n Update EMA parameters.\n\n Args:\n model (nn.Module): Model to update EMA from.\n \"\"\"\n if self.enabled:\n self.updates += 1\n d = self.decay(self.updates)\n\n msd = de_parallel(model).state_dict() # model state_dict\n for k, v in self.ema.state_dict().items():\n if v.dtype.is_floating_point: # true for FP16 and FP32\n v *= d\n v += (1 - d) * msd[k].detach()\n # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'\n\n def update_attr(self, model, include=(), exclude=(\"process_group\", \"reducer\")):\n \"\"\"\n Update attributes and save stripped model with optimizer removed.\n\n Args:\n model (nn.Module): Model to update attributes from.\n include (tuple, optional): Attributes to include.\n exclude (tuple, optional): Attributes to exclude.\n \"\"\"\n if self.enabled:\n copy_attr(self.ema, model, include, exclude)", "chunk_type": "class", "name": "ModelEMA", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 640, "end_line": 705, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "Updated Exponential Moving Average (EMA) implementation.\n\nKeeps a moving average of everything in the model state_dict (parameters and buffers).\nFor EMA details see References.\n\nTo disable EMA set the `enabled` attribute to `False`.\n\nAttributes:\n ema (nn.Module): Copy of the model in evaluation mode.\n updates (int): Number of EMA updates.\n decay (function): Decay function that determines the EMA weight.\n enabled (bool): Whether EMA is enabled.\n\nReferences:\n - https://github.com/rwightman/pytorch-image-models\n - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", 
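A compact sketch of how `ModelEMA` slots into a training loop; the toy model and step count are placeholders:

```python
import torch.nn as nn
from ultralytics.utils.torch_utils import ModelEMA

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
ema = ModelEMA(net, decay=0.9999, tau=2000)

for _ in range(3):        # stand-in for optimizer steps
    # optimizer.zero_grad(); loss.backward(); optimizer.step() would run here
    ema.update(net)       # blend current weights into the EMA copy
# ema.ema holds the smoothed model used for validation and checkpointing
```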
"ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "class_ModelEMA_df773054" }, { "content": "def strip_optimizer(f: Union[str, Path] = \"best.pt\", s: str = \"\", updates: Dict[str, Any] = None) -> Dict[str, Any]:\n \"\"\"\n Strip optimizer from 'f' to finalize training, optionally save as 's'.\n\n Args:\n f (str | Path): File path to model to strip the optimizer from.\n s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be\n overwritten.\n updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.\n\n Returns:\n (dict): The combined checkpoint dictionary.\n\n Examples:\n >>> from pathlib import Path\n >>> from ultralytics.utils.torch_utils import strip_optimizer\n >>> for f in Path(\"path/to/model/checkpoints\").rglob(\"*.pt\"):\n >>> strip_optimizer(f)\n \"\"\"\n try:\n x = torch_load(f, map_location=torch.device(\"cpu\"))\n assert isinstance(x, dict), \"checkpoint is not a Python dictionary\"\n assert \"model\" in x, \"'model' missing from checkpoint\"\n except Exception as e:\n LOGGER.warning(f\"Skipping {f}, not a valid Ultralytics model: {e}\")\n return {}\n\n metadata = {\n \"date\": datetime.now().isoformat(),\n \"version\": __version__,\n \"license\": \"AGPL-3.0 License (https://ultralytics.com/license)\",\n \"docs\": \"https://docs.ultralytics.com\",\n }\n\n # Update model\n if x.get(\"ema\"):\n x[\"model\"] = x[\"ema\"] # replace model with EMA\n if hasattr(x[\"model\"], \"args\"):\n x[\"model\"].args = dict(x[\"model\"].args) # convert from IterableSimpleNamespace to dict\n if hasattr(x[\"model\"], \"criterion\"):\n x[\"model\"].criterion = None # strip loss criterion\n x[\"model\"].half() # to FP16\n for p in x[\"model\"].parameters():\n p.requires_grad = False\n\n # Update other keys\n args = {**DEFAULT_CFG_DICT, **x.get(\"train_args\", {})} # combine args\n for k in \"optimizer\", \"best_fitness\", \"ema\", \"updates\": # keys\n x[k] = None\n x[\"epoch\"] = -1\n x[\"train_args\"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys\n # x['model'].args = x['train_args']\n\n # Save\n combined = {**metadata, **x, **(updates or {})}\n torch.save(combined, s or f) # combine dicts (prefer to the right)\n mb = os.path.getsize(s or f) / 1e6 # file size\n LOGGER.info(f\"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB\")\n return combined", "chunk_type": "function", "name": "strip_optimizer", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 708, "end_line": 766, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Strip optimizer from 'f' to finalize training, optionally save as 's'.\n\nArgs:\n f (str | Path): File path to model to strip the optimizer from.\n s (str, optional): File path to save the model with stripped optimizer to. 
If not provided, 'f' will be\n overwritten.\n updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.\n\nReturns:\n (dict): The combined checkpoint dictionary.\n\nExamples:\n >>> from pathlib import Path\n >>> from ultralytics.utils.torch_utils import strip_optimizer\n >>> for f in Path(\"path/to/model/checkpoints\").rglob(\"*.pt\"):\n >>> strip_optimizer(f)", "parameters": [ "f: Union[str, Path]", "s: str", "updates: Dict[str, Any]" ], "return_type": "Dict[str, Any]", "decorators": [], "complexity_score": 8, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_strip_optimizer_8410572a" }, { "content": "def convert_optimizer_state_dict_to_fp16(state_dict):\n \"\"\"\n Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.\n\n Args:\n state_dict (dict): Optimizer state dictionary.\n\n Returns:\n (dict): Converted optimizer state dictionary with FP16 tensors.\n \"\"\"\n for state in state_dict[\"state\"].values():\n for k, v in state.items():\n if k != \"step\" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:\n state[k] = v.half()\n\n return state_dict", "chunk_type": "function", "name": "convert_optimizer_state_dict_to_fp16", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 769, "end_line": 784, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.\n\nArgs:\n state_dict (dict): Optimizer state dictionary.\n\nReturns:\n (dict): Converted optimizer state dictionary with FP16 tensors.", "parameters": [ "state_dict" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_convert_optimizer_state_dict_to_fp16_a80c83e7" }, { "content": "def cuda_memory_usage(device=None):\n \"\"\"\n Monitor and manage CUDA memory usage.\n\n This function checks if CUDA is available and, if 
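`convert_optimizer_state_dict_to_fp16` halves only the floating-point state tensors (e.g. Adam moments) while leaving step counters alone; a small demonstration sketch:

```python
import torch
from ultralytics.utils.torch_utils import convert_optimizer_state_dict_to_fp16

net = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(net.parameters())
net(torch.randn(2, 8)).sum().backward()
opt.step()                                        # populate exp_avg / exp_avg_sq state
state = convert_optimizer_state_dict_to_fp16(opt.state_dict())["state"]
print({k: v.dtype for k, v in next(iter(state.values())).items() if torch.is_tensor(v)})
# exp_avg and exp_avg_sq are now torch.float16; 'step' keeps its original dtype
```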
so, empties the CUDA cache to free up unused memory.\n It then yields a dictionary containing memory usage information, which can be updated by the caller.\n Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.\n\n Args:\n device (torch.device, optional): The CUDA device to query memory usage for.\n\n Yields:\n (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.\n \"\"\"\n cuda_info = dict(memory=0)\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n try:\n yield cuda_info\n finally:\n cuda_info[\"memory\"] = torch.cuda.memory_reserved(device)\n else:\n yield cuda_info", "chunk_type": "function", "name": "cuda_memory_usage", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 788, "end_line": 810, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Monitor and manage CUDA memory usage.\n\nThis function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.\nIt then yields a dictionary containing memory usage information, which can be updated by the caller.\nFinally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.\n\nArgs:\n device (torch.device, optional): The CUDA device to query memory usage for.\n\nYields:\n (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.", "parameters": [ "device" ], "return_type": null, "decorators": [ "contextmanager" ], "complexity_score": 2, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_cuda_memory_usage_8815d84d" }, { "content": "def profile_ops(input, ops, n=10, device=None, max_num_obj=0):\n \"\"\"\n Ultralytics speed, memory and FLOPs profiler.\n\n Args:\n input (torch.Tensor | list): Input tensor(s) to profile.\n ops (nn.Module | list): Model or list of operations to profile.\n n (int, optional): Number of iterations to average.\n device (str | torch.device, optional): Device to profile on.\n max_num_obj (int, optional): Maximum number of objects for simulation.\n\n Returns:\n (list): Profile results for each operation.\n\n Examples:\n >>> from ultralytics.utils.torch_utils import profile_ops\n >>> input = torch.randn(16, 3, 640, 640)\n >>> m1 = lambda x: x * torch.sigmoid(x)\n >>> m2 = nn.SiLU()\n >>> profile_ops(input, [m1, m2], n=100) # profile over 100 iterations\n \"\"\"\n try:\n import thop\n except ImportError:\n thop = None # conda support without 'ultralytics-thop' installed\n\n results = []\n if not isinstance(device, torch.device):\n device = select_device(device)\n LOGGER.info(\n f\"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}\"\n 
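`cuda_memory_usage` is a context manager that reports reserved CUDA memory after the block runs and degrades gracefully on CPU-only machines, where it simply yields a zeroed dict:

```python
import torch
from ultralytics.utils.torch_utils import cuda_memory_usage

device = "cuda:0" if torch.cuda.is_available() else "cpu"
with cuda_memory_usage(device) as info:
    x = torch.randn(1024, 1024, device=device)
print(f"reserved after block: {info['memory'] / 1e9:.3f} GB")   # 0.000 on CPU
```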
f\"{'input':>24s}{'output':>24s}\"\n )\n gc.collect() # attempt to free unused memory\n torch.cuda.empty_cache()\n for x in input if isinstance(input, list) else [input]:\n x = x.to(device)\n x.requires_grad = True\n for m in ops if isinstance(ops, list) else [ops]:\n m = m.to(device) if hasattr(m, \"to\") else m # device\n m = m.half() if hasattr(m, \"half\") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m\n tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward\n try:\n flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs\n except Exception:\n flops = 0\n\n try:\n mem = 0\n for _ in range(n):\n with cuda_memory_usage(device) as cuda_info:\n t[0] = time_sync()\n y = m(x)\n t[1] = time_sync()\n try:\n (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()\n t[2] = time_sync()\n except Exception: # no backward method\n # print(e) # for debug\n t[2] = float(\"nan\")\n mem += cuda_info[\"memory\"] / 1e9 # (GB)\n tf += (t[1] - t[0]) * 1000 / n # ms per op forward\n tb += (t[2] - t[1]) * 1000 / n # ms per op backward\n if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)\n with cuda_memory_usage(device) as cuda_info:\n torch.randn(\n x.shape[0],\n max_num_obj,\n int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),\n device=device,\n dtype=torch.float32,\n )\n mem += cuda_info[\"memory\"] / 1e9 # (GB)\n s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else \"list\" for x in (x, y)) # shapes\n p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters\n LOGGER.info(f\"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}\")\n results.append([p, flops, mem, tf, tb, s_in, s_out])\n except Exception as e:\n LOGGER.info(e)\n results.append(None)\n finally:\n gc.collect() # attempt to free unused memory\n torch.cuda.empty_cache()\n return results", "chunk_type": "function", "name": "profile_ops", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 813, "end_line": 896, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Ultralytics speed, memory and FLOPs profiler.\n\nArgs:\n input (torch.Tensor | list): Input tensor(s) to profile.\n ops (nn.Module | list): Model or list of operations to profile.\n n (int, optional): Number of iterations to average.\n device (str | torch.device, optional): Device to profile on.\n max_num_obj (int, optional): Maximum number of objects for simulation.\n\nReturns:\n (list): Profile results for each operation.\n\nExamples:\n >>> from ultralytics.utils.torch_utils import profile_ops\n >>> input = torch.randn(16, 3, 640, 640)\n >>> m1 = lambda x: x * torch.sigmoid(x)\n >>> m2 = nn.SiLU()\n >>> profile_ops(input, [m1, m2], n=100) # profile over 100 iterations", "parameters": [ "input", "ops", "n", "device", "max_num_obj" ], "return_type": null, "decorators": [], "complexity_score": 14, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", 
"ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "function_profile_ops_d9138367" }, { "content": "class EarlyStopping:\n \"\"\"\n Early stopping class that stops training when a specified number of epochs have passed without improvement.\n\n Attributes:\n best_fitness (float): Best fitness value observed.\n best_epoch (int): Epoch where best fitness was observed.\n patience (int): Number of epochs to wait after fitness stops improving before stopping.\n possible_stop (bool): Flag indicating if stopping may occur next epoch.\n \"\"\"\n\n def __init__(self, patience=50):\n \"\"\"\n Initialize early stopping object.\n\n Args:\n patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.\n \"\"\"\n self.best_fitness = 0.0 # i.e. mAP\n self.best_epoch = 0\n self.patience = patience or float(\"inf\") # epochs to wait after fitness stops improving to stop\n self.possible_stop = False # possible stop may occur next epoch\n\n def __call__(self, epoch, fitness):\n \"\"\"\n Check whether to stop training.\n\n Args:\n epoch (int): Current epoch of training\n fitness (float): Fitness value of current epoch\n\n Returns:\n (bool): True if training should stop, False otherwise\n \"\"\"\n if fitness is None: # check if fitness=None (happens when val=False)\n return False\n\n if fitness > self.best_fitness or self.best_fitness == 0: # allow for early zero-fitness stage of training\n self.best_epoch = epoch\n self.best_fitness = fitness\n delta = epoch - self.best_epoch # epochs without improvement\n self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch\n stop = delta >= self.patience # stop training if patience exceeded\n if stop:\n prefix = colorstr(\"EarlyStopping: \")\n LOGGER.info(\n f\"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. \"\n f\"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\\n\"\n f\"To update EarlyStopping(patience={self.patience}) pass a new patience value, \"\n f\"i.e. 
`patience=300` or use `patience=0` to disable EarlyStopping.\"\n )\n return stop", "chunk_type": "class", "name": "EarlyStopping", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 899, "end_line": 950, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Early stopping class that stops training when a specified number of epochs have passed without improvement.\n\nAttributes:\n best_fitness (float): Best fitness value observed.\n best_epoch (int): Epoch where best fitness was observed.\n patience (int): Number of epochs to wait after fitness stops improving before stopping.\n possible_stop (bool): Flag indicating if stopping may occur next epoch.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo" ], "chunk_id": "class_EarlyStopping_c8b60078" }, { "content": "class FXModel(nn.Module):\n \"\"\"\n A custom model class for torch.fx compatibility.\n\n This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph\n manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper\n copying.\n\n Attributes:\n model (nn.Module): The original model's layers.\n \"\"\"\n\n def __init__(self, model):\n \"\"\"\n Initialize the FXModel.\n\n Args:\n model (nn.Module): The original model to wrap for torch.fx compatibility.\n \"\"\"\n super().__init__()\n copy_attr(self, model)\n # Explicitly set `model` since `copy_attr` somehow does not copy it.\n self.model = model.model\n\n def forward(self, x):\n \"\"\"\n Forward pass through the model.\n\n This method performs the forward pass through the model, handling the dependencies between layers and saving\n intermediate outputs.\n\n Args:\n x (torch.Tensor): The input tensor to the model.\n\n Returns:\n (torch.Tensor): The output tensor from the model.\n \"\"\"\n y = [] # outputs\n for m in self.model:\n if m.f != -1: # if not from previous layer\n # from earlier layers\n x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]\n x = m(x) # run\n y.append(x) # save output\n return x", "chunk_type": "class", "name": "FXModel", "file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py", "start_line": 953, "end_line": 997, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "A custom model class for torch.fx compatibility.\n\nThis class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph\nmanipulation. 
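How the `EarlyStopping` callback behaves on a plateau, with a toy fitness curve (values are illustrative):

```python
from ultralytics.utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=3)
for epoch, fitness in enumerate([0.10, 0.20, 0.20, 0.20, 0.20, 0.20]):  # mAP plateaus after epoch 1
    if stopper(epoch, fitness):
        print(f"stopped at epoch {epoch}")   # fires at epoch 4: 3 epochs with no improvement
        break
```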
It copies attributes from an existing model and explicitly sets the model attribute to ensure proper\ncopying.\n\nAttributes:\n model (nn.Module): The original model's layers.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "functools", "gc", "math", "os", "random", "time", "contextlib.contextmanager", "copy.deepcopy", "datetime.datetime", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Union", "numpy", "torch", "torch.distributed", "torch.nn", "torch.nn.functional", "ultralytics.__version__", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.DEFAULT_CFG_KEYS", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.PYTHON_VERSION", "ultralytics.utils.TORCHVISION_VERSION", "ultralytics.utils.WINDOWS", "ultralytics.utils.colorstr", "ultralytics.utils.checks.check_version", "ultralytics.utils.patches.torch_load", "ultralytics.utils.PERSISTENT_CACHE", "ultralytics.utils.autodevice.GPUInfo", "ultralytics.utils.benchmarks.ProfileModels", "thop", "thop", "cpuinfo", "nn.Module" ], "chunk_id": "class_FXModel_592a07c2" }, { "content": "from typing import List", "chunk_type": "import", "name": "List", "file_path": "ultralytics\\ultralytics\\utils\\triton.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List_8c3ed700" }, { "content": "from urllib.parse import urlsplit", "chunk_type": "import", "name": "urlsplit", "file_path": "ultralytics\\ultralytics\\utils\\triton.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_urlsplit_621806e4" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\triton.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_727d2609" }, { "content": "class TritonRemoteModel:\n \"\"\"\n Client for interacting with a remote Triton Inference Server model.\n\n This class provides a convenient interface for sending inference requests to a Triton Inference Server\n and processing the responses. 
Supports both HTTP and gRPC communication protocols.\n\n Attributes:\n endpoint (str): The name of the model on the Triton server.\n url (str): The URL of the Triton server.\n triton_client: The Triton client (either HTTP or gRPC).\n InferInput: The input class for the Triton client.\n InferRequestedOutput: The output request class for the Triton client.\n input_formats (List[str]): The data types of the model inputs.\n np_input_formats (List[type]): The numpy data types of the model inputs.\n input_names (List[str]): The names of the model inputs.\n output_names (List[str]): The names of the model outputs.\n metadata: The metadata associated with the model.\n\n Methods:\n __call__: Call the model with the given inputs and return the outputs.\n\n Examples:\n Initialize a Triton client with HTTP\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n\n Make inference with numpy arrays\n >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))\n \"\"\"\n\n def __init__(self, url: str, endpoint: str = \"\", scheme: str = \"\"):\n \"\"\"\n Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.\n\n Arguments may be provided individually or parsed from a collective 'url' argument of the form\n :////\n\n Args:\n url (str): The URL of the Triton server.\n endpoint (str, optional): The name of the model on the Triton server.\n scheme (str, optional): The communication scheme ('http' or 'grpc').\n\n Examples:\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n >>> model = TritonRemoteModel(url=\"http://localhost:8000/yolov8\")\n \"\"\"\n if not endpoint and not scheme: # Parse all args from URL string\n splits = urlsplit(url)\n endpoint = splits.path.strip(\"/\").split(\"/\", 1)[0]\n scheme = splits.scheme\n url = splits.netloc\n\n self.endpoint = endpoint\n self.url = url\n\n # Choose the Triton client based on the communication scheme\n if scheme == \"http\":\n import tritonclient.http as client # noqa\n\n self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)\n config = self.triton_client.get_model_config(endpoint)\n else:\n import tritonclient.grpc as client # noqa\n\n self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)\n config = self.triton_client.get_model_config(endpoint, as_json=True)[\"config\"]\n\n # Sort output names alphabetically, i.e. 'output0', 'output1', etc.\n config[\"output\"] = sorted(config[\"output\"], key=lambda x: x.get(\"name\"))\n\n # Define model attributes\n type_map = {\"TYPE_FP32\": np.float32, \"TYPE_FP16\": np.float16, \"TYPE_UINT8\": np.uint8}\n self.InferRequestedOutput = client.InferRequestedOutput\n self.InferInput = client.InferInput\n self.input_formats = [x[\"data_type\"] for x in config[\"input\"]]\n self.np_input_formats = [type_map[x] for x in self.input_formats]\n self.input_names = [x[\"name\"] for x in config[\"input\"]]\n self.output_names = [x[\"name\"] for x in config[\"output\"]]\n self.metadata = eval(config.get(\"parameters\", {}).get(\"metadata\", {}).get(\"string_value\", \"None\"))\n\n def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:\n \"\"\"\n Call the model with the given inputs and return inference results.\n\n Args:\n *inputs (np.ndarray): Input data to the model. 
Each array should match the expected shape and type\n for the corresponding model input.\n\n Returns:\n (List[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list\n corresponds to one of the model's output tensors.\n\n Examples:\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))\n \"\"\"\n infer_inputs = []\n input_format = inputs[0].dtype\n for i, x in enumerate(inputs):\n if x.dtype != self.np_input_formats[i]:\n x = x.astype(self.np_input_formats[i])\n infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace(\"TYPE_\", \"\"))\n infer_input.set_data_from_numpy(x)\n infer_inputs.append(infer_input)\n\n infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]\n outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)\n\n return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]", "chunk_type": "class", "name": "TritonRemoteModel", "file_path": "ultralytics\\ultralytics\\utils\\triton.py", "start_line": 9, "end_line": 117, "start_col": 0, "end_col": 104, "parent_name": null, "docstring": "Client for interacting with a remote Triton Inference Server model.\n\nThis class provides a convenient interface for sending inference requests to a Triton Inference Server\nand processing the responses. Supports both HTTP and gRPC communication protocols.\n\nAttributes:\n endpoint (str): The name of the model on the Triton server.\n url (str): The URL of the Triton server.\n triton_client: The Triton client (either HTTP or gRPC).\n InferInput: The input class for the Triton client.\n InferRequestedOutput: The output request class for the Triton client.\n input_formats (List[str]): The data types of the model inputs.\n np_input_formats (List[type]): The numpy data types of the model inputs.\n input_names (List[str]): The names of the model inputs.\n output_names (List[str]): The names of the model outputs.\n metadata: The metadata associated with the model.\n\nMethods:\n __call__: Call the model with the given inputs and return the outputs.\n\nExamples:\n Initialize a Triton client with HTTP\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n\n Make inference with numpy arrays\n >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "urllib.parse.urlsplit", "numpy", "tritonclient.http", "tritonclient.grpc" ], "chunk_id": "class_TritonRemoteModel_b8e84f73" }, { "content": "from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir", "chunk_type": "import", "name": "TASK2DATA, TASK2METRIC, get_cfg, get_save_dir", "file_path": "ultralytics\\ultralytics\\utils\\tuner.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TASK2DATA, TASK2METRIC, get_cfg, get_save_dir_7c2ec1a2" }, { "content": "from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr", "chunk_type": "import", "name": "DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr", "file_path": 
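Both construction forms of `TritonRemoteModel` resolve to the same client configuration; the sketch below assumes a Triton server is already serving a model named `yolo` on port 8000 (endpoint name and port are placeholders), so it will not run without one:

```python
import numpy as np
from ultralytics.utils.triton import TritonRemoteModel

# Collective-URL form; equivalent to TritonRemoteModel(url="localhost:8000", endpoint="yolo", scheme="http")
model = TritonRemoteModel(url="http://localhost:8000/yolo")
outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
print([o.shape for o in outputs])
```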
"ultralytics\\ultralytics\\utils\\tuner.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 98, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr_e9eb42a8" }, { "content": "def run_ray_tune(\n model,\n space: dict = None,\n grace_period: int = 10,\n gpu_per_trial: int = None,\n max_samples: int = 10,\n **train_args,\n):\n \"\"\"\n Run hyperparameter tuning using Ray Tune.\n\n Args:\n model (YOLO): Model to run the tuner on.\n space (dict, optional): The hyperparameter search space. If not provided, uses default space.\n grace_period (int, optional): The grace period in epochs of the ASHA scheduler.\n gpu_per_trial (int, optional): The number of GPUs to allocate per trial.\n max_samples (int, optional): The maximum number of trials to run.\n **train_args (Any): Additional arguments to pass to the `train()` method.\n\n Returns:\n (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\") # Load a YOLO11n model\n\n Start tuning hyperparameters for YOLO11n training on the COCO8 dataset\n >>> result_grid = model.tune(data=\"coco8.yaml\", use_ray=True)\n \"\"\"\n LOGGER.info(\"💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune\")\n if train_args is None:\n train_args = {}\n\n try:\n checks.check_requirements(\"ray[tune]\")\n\n import ray\n from ray import tune\n from ray.air import RunConfig\n from ray.air.integrations.wandb import WandbLoggerCallback\n from ray.tune.schedulers import ASHAScheduler\n except ImportError:\n raise ModuleNotFoundError('Ray Tune required but not found. 
To install run: pip install \"ray[tune]\"')\n\n try:\n import wandb\n\n assert hasattr(wandb, \"__version__\")\n except (ImportError, AssertionError):\n wandb = False\n\n checks.check_version(ray.__version__, \">=2.0.0\", \"ray\")\n default_space = {\n # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),\n \"lr0\": tune.uniform(1e-5, 1e-1),\n \"lrf\": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)\n \"momentum\": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1\n \"weight_decay\": tune.uniform(0.0, 0.001), # optimizer weight decay\n \"warmup_epochs\": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)\n \"warmup_momentum\": tune.uniform(0.0, 0.95), # warmup initial momentum\n \"box\": tune.uniform(0.02, 0.2), # box loss gain\n \"cls\": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)\n \"hsv_h\": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)\n \"hsv_s\": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)\n \"hsv_v\": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)\n \"degrees\": tune.uniform(0.0, 45.0), # image rotation (+/- deg)\n \"translate\": tune.uniform(0.0, 0.9), # image translation (+/- fraction)\n \"scale\": tune.uniform(0.0, 0.9), # image scale (+/- gain)\n \"shear\": tune.uniform(0.0, 10.0), # image shear (+/- deg)\n \"perspective\": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001\n \"flipud\": tune.uniform(0.0, 1.0), # image flip up-down (probability)\n \"fliplr\": tune.uniform(0.0, 1.0), # image flip left-right (probability)\n \"bgr\": tune.uniform(0.0, 1.0), # image channel BGR (probability)\n \"mosaic\": tune.uniform(0.0, 1.0), # image mosaic (probability)\n \"mixup\": tune.uniform(0.0, 1.0), # image mixup (probability)\n \"cutmix\": tune.uniform(0.0, 1.0), # image cutmix (probability)\n \"copy_paste\": tune.uniform(0.0, 1.0), # segment copy-paste (probability)\n }\n\n # Put the model in ray store\n task = model.task\n model_in_store = ray.put(model)\n\n def _tune(config):\n \"\"\"Train the YOLO model with the specified hyperparameters and return results.\"\"\"\n model_to_train = ray.get(model_in_store) # get the model from ray store for tuning\n model_to_train.reset_callbacks()\n config.update(train_args)\n results = model_to_train.train(**config)\n return results.results_dict\n\n # Get search space\n if not space and not train_args.get(\"resume\"):\n space = default_space\n LOGGER.warning(\"Search space not provided, using default search space.\")\n\n # Get dataset\n data = train_args.get(\"data\", TASK2DATA[task])\n space[\"data\"] = data\n if \"data\" not in train_args:\n LOGGER.warning(f'Data not provided, using default \"data={data}\".')\n\n # Define the trainable function with allocated resources\n trainable_with_resources = tune.with_resources(_tune, {\"cpu\": NUM_THREADS, \"gpu\": gpu_per_trial or 0})\n\n # Define the ASHA scheduler for hyperparameter search\n asha_scheduler = ASHAScheduler(\n time_attr=\"epoch\",\n metric=TASK2METRIC[task],\n mode=\"max\",\n max_t=train_args.get(\"epochs\") or DEFAULT_CFG_DICT[\"epochs\"] or 100,\n grace_period=grace_period,\n reduction_factor=3,\n )\n\n # Define the callbacks for the hyperparameter search\n tuner_callbacks = [WandbLoggerCallback(project=\"YOLOv8-tune\")] if wandb else []\n\n # Create the Ray Tune hyperparameter search tuner\n tune_dir = get_save_dir(\n get_cfg(\n DEFAULT_CFG,\n {**train_args, **{\"exist_ok\": train_args.pop(\"resume\", False)}}, # resume w/ same 
tune_dir\n ),\n name=train_args.pop(\"name\", \"tune\"), # runs/{task}/{tune_dir}\n ).resolve() # must be absolute dir\n tune_dir.mkdir(parents=True, exist_ok=True)\n if tune.Tuner.can_restore(tune_dir):\n LOGGER.info(f\"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...\")\n tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)\n else:\n tuner = tune.Tuner(\n trainable_with_resources,\n param_space=space,\n tune_config=tune.TuneConfig(\n scheduler=asha_scheduler,\n num_samples=max_samples,\n trial_name_creator=lambda trial: f\"{trial.trainable_name}_{trial.trial_id}\",\n trial_dirname_creator=lambda trial: f\"{trial.trainable_name}_{trial.trial_id}\",\n ),\n run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),\n )\n\n # Run the hyperparameter search\n tuner.fit()\n\n # Get the results of the hyperparameter search\n results = tuner.get_results()\n\n # Shut down Ray to clean up workers\n ray.shutdown()\n\n return results", "chunk_type": "function", "name": "run_ray_tune", "file_path": "ultralytics\\ultralytics\\utils\\tuner.py", "start_line": 7, "end_line": 159, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Run hyperparameter tuning using Ray Tune.\n\nArgs:\n model (YOLO): Model to run the tuner on.\n space (dict, optional): The hyperparameter search space. If not provided, uses default space.\n grace_period (int, optional): The grace period in epochs of the ASHA scheduler.\n gpu_per_trial (int, optional): The number of GPUs to allocate per trial.\n max_samples (int, optional): The maximum number of trials to run.\n **train_args (Any): Additional arguments to pass to the `train()` method.\n\nReturns:\n (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.\n\nExamples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\") # Load a YOLO11n model\n\n Start tuning hyperparameters for YOLO11n training on the COCO8 dataset\n >>> result_grid = model.tune(data=\"coco8.yaml\", use_ray=True)", "parameters": [ "model", "space: dict", "grace_period: int", "gpu_per_trial: int", "max_samples: int" ], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "ultralytics.cfg.TASK2DATA", "ultralytics.cfg.TASK2METRIC", "ultralytics.cfg.get_cfg", "ultralytics.cfg.get_save_dir", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.checks", "ultralytics.utils.colorstr", "ray", "ray.tune", "ray.air.RunConfig", "ray.air.integrations.wandb.WandbLoggerCallback", "ray.tune.schedulers.ASHAScheduler", "wandb" ], "chunk_id": "function_run_ray_tune_8bb9d330" }, { "content": "import contextlib", "chunk_type": "import", "name": "contextlib", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_contextlib_4d4c47c8" }, { "content": "import importlib.metadata", "chunk_type": "import", "name": "importlib.metadata", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
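A hedged usage sketch for run_ray_tune above, calling it directly with a small custom search space; the ranges are illustrative only, and ray[tune] plus a pretrained checkpoint are assumed to be available:

from ray import tune  # pip install "ray[tune]"

from ultralytics import YOLO
from ultralytics.utils.tuner import run_ray_tune

model = YOLO("yolo11n.pt")

# When a space is provided it replaces default_space entirely, so list every hyperparameter to tune
search_space = {
    "lr0": tune.uniform(1e-5, 1e-1),
    "momentum": tune.uniform(0.6, 0.98),
}

result_grid = run_ray_tune(
    model,
    space=search_space,
    grace_period=5,
    gpu_per_trial=0,    # CPU-only trials
    max_samples=8,
    data="coco8.yaml",  # forwarded to train() via **train_args
    epochs=10,
)
print(result_grid.get_best_result(metric="metrics/mAP50-95(B)", mode="max"))  # detect-task metric per TASK2METRIC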
"import_importlib.metadata_1f56d5eb" }, { "content": "import inspect", "chunk_type": "import", "name": "inspect", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_inspect_8238d903" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_a7821fa1" }, { "content": "import logging", "chunk_type": "import", "name": "logging", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_logging_2ff27b9b" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_a474651e" }, { "content": "import platform", "chunk_type": "import", "name": "platform", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_platform_aaf3ae19" }, { "content": "import re", "chunk_type": "import", "name": "re", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_re_1aa51818" }, { "content": "import subprocess", "chunk_type": "import", "name": "subprocess", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_subprocess_c14a63da" }, { "content": "import sys", "chunk_type": "import", "name": "sys", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_sys_7e23cba0" }, { "content": "import threading", "chunk_type": "import", "name": "threading", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_threading_defa8e84" }, { "content": "import time", "chunk_type": 
"import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_5b823329" }, { "content": "import warnings", "chunk_type": "import", "name": "warnings", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_warnings_22e1693e" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_24ab0945" }, { "content": "from threading import Lock", "chunk_type": "import", "name": "Lock", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Lock_b62ee358" }, { "content": "from types import SimpleNamespace", "chunk_type": "import", "name": "SimpleNamespace", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SimpleNamespace_c9f5c5d7" }, { "content": "from typing import Union", "chunk_type": "import", "name": "Union", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Union_7b3abf18" }, { "content": "from urllib.parse import unquote", "chunk_type": "import", "name": "unquote", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_unquote_b5450cb4" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 22, "end_line": 22, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_377856b6" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_615d6760" }, { "content": "import torch", 
"chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_6e178de0" }, { "content": "import tqdm", "chunk_type": "import", "name": "tqdm", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_tqdm_fc1d95d4" }, { "content": "from ultralytics import __version__", "chunk_type": "import", "name": "__version__", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 27, "end_line": 27, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import___version___ea70b14d" }, { "content": "from ultralytics.utils.patches import imread, imshow, imwrite, torch_save # for patches", "chunk_type": "import", "name": "imread, imshow, imwrite, torch_save", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 28, "end_line": 28, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_imread, imshow, imwrite, torch_save_f1dceae8" }, { "content": "RANK = int(os.getenv(\"RANK\", -1))", "chunk_type": "variable", "name": "RANK", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 31, "end_line": 31, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_RANK_9e3ef401" }, { "content": "LOCAL_RANK = int(os.getenv(\"LOCAL_RANK\", -1)) # https://pytorch.org/docs/stable/elastic/run.html", "chunk_type": "variable", "name": "LOCAL_RANK", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 32, "end_line": 32, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_LOCAL_RANK_90f2524a" }, { "content": "ARGV = sys.argv or [\"\", \"\"] # sometimes sys.argv = []", "chunk_type": "variable", "name": "ARGV", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 35, "end_line": 35, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ARGV_8d0e7e0e" }, { "content": "FILE = Path(__file__).resolve()", "chunk_type": "variable", "name": "FILE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 36, "end_line": 36, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_FILE_835ae00c" }, { "content": "ROOT = FILE.parents[1] # YOLO", "chunk_type": "variable", "name": "ROOT", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", 
"start_line": 37, "end_line": 37, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ROOT_ba9cd46f" }, { "content": "ASSETS = ROOT / \"assets\" # default images", "chunk_type": "variable", "name": "ASSETS", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 38, "end_line": 38, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ASSETS_dc99f744" }, { "content": "ASSETS_URL = \"https://github.com/ultralytics/assets/releases/download/v0.0.0\" # assets GitHub URL", "chunk_type": "variable", "name": "ASSETS_URL", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 39, "end_line": 39, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ASSETS_URL_7853b028" }, { "content": "DEFAULT_CFG_PATH = ROOT / \"cfg/default.yaml\"", "chunk_type": "variable", "name": "DEFAULT_CFG_PATH", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 40, "end_line": 40, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEFAULT_CFG_PATH_2e90f729" }, { "content": "NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLO multiprocessing threads", "chunk_type": "variable", "name": "NUM_THREADS", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 41, "end_line": 41, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_NUM_THREADS_640ac97b" }, { "content": "AUTOINSTALL = str(os.getenv(\"YOLO_AUTOINSTALL\", True)).lower() == \"true\" # global auto-install mode", "chunk_type": "variable", "name": "AUTOINSTALL", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 42, "end_line": 42, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_AUTOINSTALL_66403f21" }, { "content": "VERBOSE = str(os.getenv(\"YOLO_VERBOSE\", True)).lower() == \"true\" # global verbose mode", "chunk_type": "variable", "name": "VERBOSE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 43, "end_line": 43, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_VERBOSE_a14f86bb" }, { "content": "TQDM_BAR_FORMAT = \"{l_bar}{bar:10}{r_bar}\" if VERBOSE else None # tqdm bar format", "chunk_type": "variable", "name": "TQDM_BAR_FORMAT", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 44, "end_line": 44, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TQDM_BAR_FORMAT_16e9892d" }, { "content": "LOGGING_NAME = 
\"ultralytics\"", "chunk_type": "variable", "name": "LOGGING_NAME", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 45, "end_line": 45, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_LOGGING_NAME_c6583563" }, { "content": "MACOS_VERSION = platform.mac_ver()[0] if MACOS else None", "chunk_type": "variable", "name": "MACOS_VERSION", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 47, "end_line": 47, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_MACOS_VERSION_72e4c96b" }, { "content": "ARM64 = platform.machine() in {\"arm64\", \"aarch64\"} # ARM64 booleans", "chunk_type": "variable", "name": "ARM64", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 48, "end_line": 48, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ARM64_f4087415" }, { "content": "PYTHON_VERSION = platform.python_version()", "chunk_type": "variable", "name": "PYTHON_VERSION", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 49, "end_line": 49, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PYTHON_VERSION_229c3f66" }, { "content": "TORCH_VERSION = torch.__version__", "chunk_type": "variable", "name": "TORCH_VERSION", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 50, "end_line": 50, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCH_VERSION_f77f9b49" }, { "content": "TORCHVISION_VERSION = importlib.metadata.version(\"torchvision\") # faster than importing torchvision", "chunk_type": "variable", "name": "TORCHVISION_VERSION", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 51, "end_line": 51, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TORCHVISION_VERSION_6646c490" }, { "content": "IS_VSCODE = os.environ.get(\"TERM_PROGRAM\", False) == \"vscode\"", "chunk_type": "variable", "name": "IS_VSCODE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 52, "end_line": 52, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_VSCODE_78a0d751" }, { "content": "RKNN_CHIPS = frozenset(\n {\n \"rk3588\",\n \"rk3576\",\n \"rk3566\",\n \"rk3568\",\n \"rk3562\",\n \"rv1103\",\n \"rv1106\",\n \"rv1103b\",\n \"rv1106b\",\n \"rk2118\",\n }\n) # Rockchip processors available for export", "chunk_type": "variable", "name": "RKNN_CHIPS", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 53, "end_line": 66, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, 
"parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_RKNN_CHIPS_7a7f5d56" }, { "content": "HELP_MSG = \"\"\"\n Examples for running Ultralytics:\n\n 1. Install the ultralytics package:\n\n pip install ultralytics\n\n 2. Use the Python SDK:\n\n from ultralytics import YOLO\n\n # Load a model\n model = YOLO(\"yolo11n.yaml\") # build a new model from scratch\n model = YOLO(\"yolo11n.pt\") # load a pretrained model (recommended for training)\n\n # Use the model\n results = model.train(data=\"coco8.yaml\", epochs=3) # train the model\n results = model.val() # evaluate model performance on the validation set\n results = model(\"https://ultralytics.com/images/bus.jpg\") # predict on an image\n success = model.export(format=\"onnx\") # export the model to ONNX format\n\n 3. Use the command line interface (CLI):\n\n Ultralytics 'yolo' CLI commands use the following syntax:\n\n yolo TASK MODE ARGS\n\n Where TASK (optional) is one of [detect, segment, classify, pose, obb]\n MODE (required) is one of [train, val, predict, export, track, benchmark]\n ARGS (optional) are any number of custom \"arg=value\" pairs like \"imgsz=320\" that override defaults.\n See all ARGS at https://docs.ultralytics.com/usage/cfg or with \"yolo cfg\"\n\n - Train a detection model for 10 epochs with an initial learning_rate of 0.01\n yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01\n\n - Predict a YouTube video using a pretrained segmentation model at image size 320:\n yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320\n\n - Val a pretrained detection model at batch-size 1 and image size 640:\n yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640\n\n - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)\n yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128\n\n - Run special commands:\n yolo help\n yolo checks\n yolo version\n yolo settings\n yolo copy-cfg\n yolo cfg\n\n Docs: https://docs.ultralytics.com\n Community: https://community.ultralytics.com\n GitHub: https://github.com/ultralytics/ultralytics\n \"\"\"", "chunk_type": "variable", "name": "HELP_MSG", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 67, "end_line": 122, "start_col": 0, "end_col": 7, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_HELP_MSG_5239f3ea" }, { "content": "class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):\n \"\"\"\n A custom TQDM progress bar class that extends the original tqdm functionality.\n\n This class modifies the behavior of the original tqdm progress bar based on global settings and provides\n additional customization options for Ultralytics projects. The progress bar is automatically disabled when\n VERBOSE is False or when explicitly disabled.\n\n Attributes:\n disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and\n any passed 'disable' argument.\n bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not\n explicitly set.\n\n Methods:\n __init__: Initialize the TQDM object with custom settings.\n __iter__: Return self as iterator to satisfy Iterable interface.\n\n Examples:\n >>> from ultralytics.utils import TQDM\n >>> for i in TQDM(range(100)):\n ... # Your processing code here\n ... 
pass\n \"\"\"\n\n def __init__(self, *args, **kwargs):\n \"\"\"\n Initialize a custom TQDM progress bar with Ultralytics-specific settings.\n\n Args:\n *args (Any): Variable length argument list to be passed to the original tqdm constructor.\n **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor.\n\n Notes:\n - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs.\n - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs.\n\n Examples:\n >>> from ultralytics.utils import TQDM\n >>> for i in TQDM(range(100)):\n ... # Your code here\n ... pass\n \"\"\"\n warnings.filterwarnings(\"ignore\", category=tqdm.TqdmExperimentalWarning) # suppress tqdm.rich warning\n kwargs[\"disable\"] = not VERBOSE or kwargs.get(\"disable\", False)\n kwargs.setdefault(\"bar_format\", TQDM_BAR_FORMAT) # override default value if passed\n super().__init__(*args, **kwargs)\n\n def __iter__(self):\n \"\"\"Return self as iterator to satisfy Iterable interface.\"\"\"\n return super().__iter__()", "chunk_type": "class", "name": "TQDM", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 137, "end_line": 187, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "A custom TQDM progress bar class that extends the original tqdm functionality.\n\nThis class modifies the behavior of the original tqdm progress bar based on global settings and provides\nadditional customization options for Ultralytics projects. The progress bar is automatically disabled when\nVERBOSE is False or when explicitly disabled.\n\nAttributes:\n disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and\n any passed 'disable' argument.\n bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not\n explicitly set.\n\nMethods:\n __init__: Initialize the TQDM object with custom settings.\n __iter__: Return self as iterator to satisfy Iterable interface.\n\nExamples:\n >>> from ultralytics.utils import TQDM\n >>> for i in TQDM(range(100)):\n ... # Your processing code here\n ... 
pass", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io", "rich.tqdm if TQDM_RICH else tqdm.tqdm" ], "chunk_id": "class_TQDM_87fbee36" }, { "content": "class DataExportMixin:\n \"\"\"\n Mixin class for exporting validation metrics or prediction results in various formats.\n\n This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results\n from classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas\n DataFrame, CSV, XML, HTML, JSON and SQLite (SQL).\n\n Methods:\n to_df: Convert summary to a Pandas DataFrame.\n to_csv: Export results as a CSV string.\n to_xml: Export results as an XML string (requires `lxml`).\n to_html: Export results as an HTML table.\n to_json: Export results as a JSON string.\n tojson: Deprecated alias for `to_json()`.\n to_sql: Export results to an SQLite database.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model(\"image.jpg\")\n >>> df = results.to_df()\n >>> print(df)\n >>> csv_data = results.to_csv()\n >>> results.to_sql(table_name=\"yolo_results\")\n \"\"\"\n\n def to_df(self, normalize=False, decimals=5):\n \"\"\"\n Create a pandas DataFrame from the prediction results summary or validation metrics.\n\n Args:\n normalize (bool, optional): Normalize numerical values for easier comparison.\n decimals (int, optional): Decimal places to round floats.\n\n Returns:\n (DataFrame): DataFrame containing the summary data.\n \"\"\"\n import pandas as pd # scope for faster 'import ultralytics'\n\n return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))\n\n def to_csv(self, normalize=False, decimals=5):\n \"\"\"\n Export results to CSV string format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n\n Returns:\n (str): CSV content as string.\n \"\"\"\n return self.to_df(normalize=normalize, decimals=decimals).to_csv()\n\n def to_xml(self, normalize=False, decimals=5):\n \"\"\"\n Export results to XML format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n\n Returns:\n (str): XML string.\n\n Notes:\n Requires `lxml` package to be installed.\n \"\"\"\n df = self.to_df(normalize=normalize, decimals=decimals)\n return '\\n' if df.empty else df.to_xml(parser=\"etree\")\n\n def to_html(self, normalize=False, decimals=5, index=False):\n \"\"\"\n Export results to HTML table format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n index (bool, optional): Whether to include index column in the HTML table.\n\n Returns:\n (str): HTML representation of the results.\n \"\"\"\n df = self.to_df(normalize=normalize, decimals=decimals)\n return \"
\" if df.empty else df.to_html(index=index)\n\n def tojson(self, normalize=False, decimals=5):\n \"\"\"Deprecated version of to_json().\"\"\"\n LOGGER.warning(\"'result.tojson()' is deprecated, replace with 'result.to_json()'.\")\n return self.to_json(normalize, decimals)\n\n def to_json(self, normalize=False, decimals=5):\n \"\"\"\n Export results to JSON format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n\n Returns:\n (str): JSON-formatted string of the results.\n \"\"\"\n return self.to_df(normalize=normalize, decimals=decimals).to_json(orient=\"records\", indent=2)\n\n def to_sql(self, normalize=False, decimals=5, table_name=\"results\", db_path=\"results.db\"):\n \"\"\"\n Save results to an SQLite database.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n table_name (str, optional): Name of the SQL table.\n db_path (str, optional): SQLite database file path.\n \"\"\"\n df = self.to_df(normalize, decimals)\n if df.empty or df.columns.empty: # Exit if df is None or has no columns (i.e., no schema)\n return\n\n import sqlite3\n\n conn = sqlite3.connect(db_path)\n cursor = conn.cursor()\n\n # Dynamically create table schema based on summary to support prediction and validation results export\n columns = []\n for col in df.columns:\n sample_val = df[col].dropna().iloc[0] if not df[col].dropna().empty else \"\"\n if isinstance(sample_val, dict):\n col_type = \"TEXT\"\n elif isinstance(sample_val, (float, int)):\n col_type = \"REAL\"\n else:\n col_type = \"TEXT\"\n columns.append(f'\"{col}\" {col_type}') # Quote column names to handle special characters like hyphens\n\n # Create table (Drop table from db if it's already exist)\n cursor.execute(f'DROP TABLE IF EXISTS \"{table_name}\"')\n cursor.execute(f'CREATE TABLE \"{table_name}\" (id INTEGER PRIMARY KEY AUTOINCREMENT, {\", \".join(columns)})')\n\n for _, row in df.iterrows():\n values = [json.dumps(v) if isinstance(v, dict) else v for v in row]\n column_names = \", \".join(f'\"{col}\"' for col in df.columns)\n placeholders = \", \".join(\"?\" for _ in df.columns)\n cursor.execute(f'INSERT INTO \"{table_name}\" ({column_names}) VALUES ({placeholders})', values)\n\n conn.commit()\n conn.close()\n LOGGER.info(f\"Results saved to SQL table '{table_name}' in '{db_path}'.\")", "chunk_type": "class", "name": "DataExportMixin", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 190, "end_line": 337, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "Mixin class for exporting validation metrics or prediction results in various formats.\n\nThis class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results\nfrom classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas\nDataFrame, CSV, XML, HTML, JSON and SQLite (SQL).\n\nMethods:\n to_df: Convert summary to a Pandas DataFrame.\n to_csv: Export results as a CSV string.\n to_xml: Export results as an XML string (requires `lxml`).\n to_html: Export results as an HTML table.\n to_json: Export results as a JSON string.\n tojson: Deprecated alias for `to_json()`.\n to_sql: Export results to an SQLite database.\n\nExamples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model(\"image.jpg\")\n >>> df = results.to_df()\n >>> print(df)\n >>> csv_data = results.to_csv()\n >>> results.to_sql(table_name=\"yolo_results\")", "parameters": null, 
"return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "class_DataExportMixin_9a9b7f10" }, { "content": "class SimpleClass:\n \"\"\"\n A simple base class for creating objects with string representations of their attributes.\n\n This class provides a foundation for creating objects that can be easily printed or represented as strings,\n showing all their non-callable attributes. It's useful for debugging and introspection of object states.\n\n Methods:\n __str__: Return a human-readable string representation of the object.\n __repr__: Return a machine-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n\n Examples:\n >>> class MyClass(SimpleClass):\n ... def __init__(self):\n ... self.x = 10\n ... self.y = \"hello\"\n >>> obj = MyClass()\n >>> print(obj)\n __main__.MyClass object with attributes:\n\n x: 10\n y: 'hello'\n\n Notes:\n - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.\n - The string representation includes the module and class name of the object.\n - Callable attributes and attributes starting with an underscore are excluded from the string representation.\n \"\"\"\n\n def __str__(self):\n \"\"\"Return a human-readable string representation of the object.\"\"\"\n attr = []\n for a in dir(self):\n v = getattr(self, a)\n if not callable(v) and not a.startswith(\"_\"):\n if isinstance(v, SimpleClass):\n # Display only the module and class name for subclasses\n s = f\"{a}: {v.__module__}.{v.__class__.__name__} object\"\n else:\n s = f\"{a}: {repr(v)}\"\n attr.append(s)\n return f\"{self.__module__}.{self.__class__.__name__} object with attributes:\\n\\n\" + \"\\n\".join(attr)\n\n def __repr__(self):\n \"\"\"Return a machine-readable string representation of the object.\"\"\"\n return self.__str__()\n\n def __getattr__(self, attr):\n \"\"\"Provide a custom attribute access error message with helpful information.\"\"\"\n name = self.__class__.__name__\n raise AttributeError(f\"'{name}' object has no attribute '{attr}'. See valid attributes below.\\n{self.__doc__}\")", "chunk_type": "class", "name": "SimpleClass", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 340, "end_line": 391, "start_col": 0, "end_col": 119, "parent_name": null, "docstring": "A simple base class for creating objects with string representations of their attributes.\n\nThis class provides a foundation for creating objects that can be easily printed or represented as strings,\nshowing all their non-callable attributes. 
It's useful for debugging and introspection of object states.\n\nMethods:\n __str__: Return a human-readable string representation of the object.\n __repr__: Return a machine-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n\nExamples:\n >>> class MyClass(SimpleClass):\n ... def __init__(self):\n ... self.x = 10\n ... self.y = \"hello\"\n >>> obj = MyClass()\n >>> print(obj)\n __main__.MyClass object with attributes:\n\n x: 10\n y: 'hello'\n\nNotes:\n - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.\n - The string representation includes the module and class name of the object.\n - Callable attributes and attributes starting with an underscore are excluded from the string representation.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "class_SimpleClass_0d66c6b0" }, { "content": "class IterableSimpleNamespace(SimpleNamespace):\n \"\"\"\n An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.\n\n This class extends the SimpleNamespace class with additional methods for iteration, string representation,\n and attribute access. It is designed to be used as a convenient container for storing and accessing\n configuration parameters.\n\n Methods:\n __iter__: Return an iterator of key-value pairs from the namespace's attributes.\n __str__: Return a human-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n get: Retrieve the value of a specified key, or a default value if the key doesn't exist.\n\n Examples:\n >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)\n >>> for k, v in cfg:\n ... print(f\"{k}: {v}\")\n a: 1\n b: 2\n c: 3\n >>> print(cfg)\n a=1\n b=2\n c=3\n >>> cfg.get(\"b\")\n 2\n >>> cfg.get(\"d\", \"default\")\n 'default'\n\n Notes:\n This class is particularly useful for storing configuration parameters in a more accessible\n and iterable format compared to a standard dictionary.\n \"\"\"\n\n def __iter__(self):\n \"\"\"Return an iterator of key-value pairs from the namespace's attributes.\"\"\"\n return iter(vars(self).items())\n\n def __str__(self):\n \"\"\"Return a human-readable string representation of the object.\"\"\"\n return \"\\n\".join(f\"{k}={v}\" for k, v in vars(self).items())\n\n def __getattr__(self, attr):\n \"\"\"Provide a custom attribute access error message with helpful information.\"\"\"\n name = self.__class__.__name__\n raise AttributeError(\n f\"\"\"\n '{name}' object has no attribute '{attr}'. 
This may be caused by a modified or out of date ultralytics\n 'default.yaml' file.\\nPlease update your code with 'pip install -U ultralytics' and if necessary replace\n {DEFAULT_CFG_PATH} with the latest version from\n https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml\n \"\"\"\n )\n\n def get(self, key, default=None):\n \"\"\"Return the value of the specified key if it exists; otherwise, return the default value.\"\"\"\n return getattr(self, key, default)", "chunk_type": "class", "name": "IterableSimpleNamespace", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 394, "end_line": 451, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.\n\nThis class extends the SimpleNamespace class with additional methods for iteration, string representation,\nand attribute access. It is designed to be used as a convenient container for storing and accessing\nconfiguration parameters.\n\nMethods:\n __iter__: Return an iterator of key-value pairs from the namespace's attributes.\n __str__: Return a human-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n get: Retrieve the value of a specified key, or a default value if the key doesn't exist.\n\nExamples:\n >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)\n >>> for k, v in cfg:\n ... print(f\"{k}: {v}\")\n a: 1\n b: 2\n c: 3\n >>> print(cfg)\n a=1\n b=2\n c=3\n >>> cfg.get(\"b\")\n 2\n >>> cfg.get(\"d\", \"default\")\n 'default'\n\nNotes:\n This class is particularly useful for storing configuration parameters in a more accessible\n and iterable format compared to a standard dictionary.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io", "SimpleNamespace" ], "chunk_id": "class_IterableSimpleNamespace_68d9d1c1" }, { "content": "def plt_settings(rcparams=None, backend=\"Agg\"):\n \"\"\"\n Decorator to temporarily set rc parameters and the backend for a plotting function.\n\n Args:\n rcparams (dict, optional): Dictionary of rc parameters to set.\n backend (str, optional): Name of the backend to use.\n\n Returns:\n (Callable): Decorated function with temporarily set rc parameters and backend.\n\n Examples:\n >>> @plt_settings({\"font.size\": 12})\n >>> def plot_function():\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()\n\n >>> with plt_settings({\"font.size\": 12}):\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... 
plt.show()\n \"\"\"\n if rcparams is None:\n rcparams = {\"font.size\": 11}\n\n def decorator(func):\n \"\"\"Decorator to apply temporary rc parameters and backend to a function.\"\"\"\n\n def wrapper(*args, **kwargs):\n \"\"\"Set rc parameters and backend, call the original function, and restore the settings.\"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n original_backend = plt.get_backend()\n switch = backend.lower() != original_backend.lower()\n if switch:\n plt.close(\"all\") # auto-close()ing of figures upon backend switching is deprecated since 3.8\n plt.switch_backend(backend)\n\n # Plot with backend and always revert to original backend\n try:\n with plt.rc_context(rcparams):\n result = func(*args, **kwargs)\n finally:\n if switch:\n plt.close(\"all\")\n plt.switch_backend(original_backend)\n return result\n\n return wrapper\n\n return decorator", "chunk_type": "function", "name": "plt_settings", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 454, "end_line": 505, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Decorator to temporarily set rc parameters and the backend for a plotting function.\n\nArgs:\n rcparams (dict, optional): Dictionary of rc parameters to set.\n backend (str, optional): Name of the backend to use.\n\nReturns:\n (Callable): Decorated function with temporarily set rc parameters and backend.\n\nExamples:\n >>> @plt_settings({\"font.size\": 12})\n >>> def plot_function():\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()\n\n >>> with plt_settings({\"font.size\": 12}):\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()", "parameters": [ "rcparams", "backend" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_plt_settings_6860eb76" }, { "content": "def set_logging(name=\"LOGGING_NAME\", verbose=True):\n \"\"\"\n Set up logging with UTF-8 encoding and configurable verbosity.\n\n This function configures logging for the Ultralytics library, setting the appropriate logging level and\n formatter based on the verbosity flag and the current process rank. 
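Since plt_settings returns a decorator, a minimal sketch of wrapping a plotting helper so it renders headlessly with the Agg backend and temporary rc params (the file name and figure content are illustrative):

import matplotlib.pyplot as plt

from ultralytics.utils import plt_settings


@plt_settings({"font.size": 8}, backend="Agg")
def save_demo_plot(path="demo_plot.png"):
    """Plot a tiny line chart and save it; rc params and backend are restored afterwards."""
    plt.figure()
    plt.plot([1, 2, 3], [1, 4, 9])
    plt.savefig(path)
    plt.close()


save_demo_plot()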
It handles special cases for Windows\n environments where UTF-8 encoding might not be the default.\n\n Args:\n name (str): Name of the logger.\n verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise.\n\n Returns:\n (logging.Logger): Configured logger object.\n\n Examples:\n >>> set_logging(name=\"ultralytics\", verbose=True)\n >>> logger = logging.getLogger(\"ultralytics\")\n >>> logger.info(\"This is an info message\")\n\n Notes:\n - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.\n - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.\n - The function sets up a StreamHandler with the appropriate formatter and level.\n - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.\n \"\"\"\n level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings\n\n class PrefixFormatter(logging.Formatter):\n def format(self, record):\n \"\"\"Format log records with prefixes based on level.\"\"\"\n # Apply prefixes based on log level\n if record.levelno == logging.WARNING:\n prefix = \"WARNING ⚠️\" if not WINDOWS else \"WARNING\"\n record.msg = f\"{prefix} {record.msg}\"\n elif record.levelno == logging.ERROR:\n prefix = \"ERROR ❌\" if not WINDOWS else \"ERROR\"\n record.msg = f\"{prefix} {record.msg}\"\n\n # Handle emojis in message based on platform\n formatted_message = super().format(record)\n return emojis(formatted_message)\n\n formatter = PrefixFormatter(\"%(message)s\")\n\n # Handle Windows UTF-8 encoding issues\n if WINDOWS and hasattr(sys.stdout, \"encoding\") and sys.stdout.encoding != \"utf-8\":\n try:\n # Attempt to reconfigure stdout to use UTF-8 encoding if possible\n if hasattr(sys.stdout, \"reconfigure\"):\n sys.stdout.reconfigure(encoding=\"utf-8\")\n # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper\n elif hasattr(sys.stdout, \"buffer\"):\n import io\n\n sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding=\"utf-8\")\n except Exception:\n pass\n\n # Create and configure the StreamHandler with the appropriate formatter and level\n stream_handler = logging.StreamHandler(sys.stdout)\n stream_handler.setFormatter(formatter)\n stream_handler.setLevel(level)\n\n # Set up the logger\n logger = logging.getLogger(name)\n logger.setLevel(level)\n logger.addHandler(stream_handler)\n logger.propagate = False\n return logger", "chunk_type": "function", "name": "set_logging", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 508, "end_line": 577, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Set up logging with UTF-8 encoding and configurable verbosity.\n\nThis function configures logging for the Ultralytics library, setting the appropriate logging level and\nformatter based on the verbosity flag and the current process rank. 
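A small, hedged sketch of reusing set_logging for a custom-named logger; the name "my_project" is arbitrary, and the Windows UTF-8/emoji handling described above is applied automatically:

import logging

from ultralytics.utils import set_logging

logger = set_logging(name="my_project", verbose=True)  # INFO level on the main process, ERROR otherwise
logger.info("training started")
logger.warning("low disk space")  # rendered with a WARNING prefix by the custom formatter

# The same logger can be fetched anywhere else by name
assert logging.getLogger("my_project") is logger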
It handles special cases for Windows\nenvironments where UTF-8 encoding might not be the default.\n\nArgs:\n name (str): Name of the logger.\n verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise.\n\nReturns:\n (logging.Logger): Configured logger object.\n\nExamples:\n >>> set_logging(name=\"ultralytics\", verbose=True)\n >>> logger = logging.getLogger(\"ultralytics\")\n >>> logger.info(\"This is an info message\")\n\nNotes:\n - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.\n - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.\n - The function sets up a StreamHandler with the appropriate formatter and level.\n - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.", "parameters": [ "name", "verbose" ], "return_type": null, "decorators": [], "complexity_score": 7, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_set_logging_e62f00e3" }, { "content": "LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)", "chunk_type": "variable", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 581, "end_line": 581, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_LOGGER_d3f1010c" }, { "content": "def emojis(string=\"\"):\n \"\"\"Return platform-dependent emoji-safe version of string.\"\"\"\n return string.encode().decode(\"ascii\", \"ignore\") if WINDOWS else string", "chunk_type": "function", "name": "emojis", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 586, "end_line": 588, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": "Return platform-dependent emoji-safe version of string.", "parameters": [ "string" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_emojis_37a9f4c6" }, { "content": "class ThreadingLocked:\n \"\"\"\n A decorator class for ensuring thread-safe execution of a function or 
method.\n\n This class can be used as a decorator to make sure that if the decorated function is called from multiple threads,\n only one thread at a time will be able to execute the function.\n\n Attributes:\n lock (threading.Lock): A lock object used to manage access to the decorated function.\n\n Examples:\n >>> from ultralytics.utils import ThreadingLocked\n >>> @ThreadingLocked()\n >>> def my_function():\n ... # Your code here\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the decorator class with a threading lock.\"\"\"\n self.lock = threading.Lock()\n\n def __call__(self, f):\n \"\"\"Run thread-safe execution of function or method.\"\"\"\n from functools import wraps\n\n @wraps(f)\n def decorated(*args, **kwargs):\n \"\"\"Apply thread-safety to the decorated function or method.\"\"\"\n with self.lock:\n return f(*args, **kwargs)\n\n return decorated", "chunk_type": "class", "name": "ThreadingLocked", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 591, "end_line": 622, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "A decorator class for ensuring thread-safe execution of a function or method.\n\nThis class can be used as a decorator to make sure that if the decorated function is called from multiple threads,\nonly one thread at a time will be able to execute the function.\n\nAttributes:\n lock (threading.Lock): A lock object used to manage access to the decorated function.\n\nExamples:\n >>> from ultralytics.utils import ThreadingLocked\n >>> @ThreadingLocked()\n >>> def my_function():\n ... # Your code here", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "class_ThreadingLocked_a0571b88" }, { "content": "class YAML:\n \"\"\"\n YAML utility class for efficient file operations with automatic C-implementation detection.\n\n This class provides optimized YAML loading and saving operations using PyYAML's fastest available implementation\n (C-based when possible). It implements a singleton pattern with lazy initialization, allowing direct class method\n usage without explicit instantiation. 
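The ThreadingLocked chunk just above is a plain mutex decorator. A small illustrative sketch (append_result is a hypothetical function; assumes ultralytics is importable) showing that concurrent calls are serialized:

```python
import threading
import time

from ultralytics.utils import ThreadingLocked

@ThreadingLocked()  # one shared lock per decorator instance
def append_result(results, value):
    """Read-modify-write that would race without the lock."""
    snapshot = list(results)
    time.sleep(0.01)                # widen the race window on purpose
    results[:] = snapshot + [value]

results = []
threads = [threading.Thread(target=append_result, args=(results, i)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(results))  # [0, 1, 2, 3, 4, 5, 6, 7] -- no lost updates, calls ran one at a time
```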
The class handles file path creation, validation, and character encoding\n issues automatically.\n\n The implementation prioritizes performance through:\n - Automatic C-based loader/dumper selection when available\n - Singleton pattern to reuse the same instance\n - Lazy initialization to defer import costs until needed\n - Fallback mechanisms for handling problematic YAML content\n\n Attributes:\n _instance: Internal singleton instance storage.\n yaml: Reference to the PyYAML module.\n SafeLoader: Best available YAML loader (CSafeLoader if available).\n SafeDumper: Best available YAML dumper (CSafeDumper if available).\n\n Examples:\n >>> data = YAML.load(\"config.yaml\")\n >>> data[\"new_value\"] = 123\n >>> YAML.save(\"updated_config.yaml\", data)\n >>> YAML.print(data)\n \"\"\"\n\n _instance = None\n\n @classmethod\n def _get_instance(cls):\n \"\"\"Initialize singleton instance on first use.\"\"\"\n if cls._instance is None:\n cls._instance = cls()\n return cls._instance\n\n def __init__(self):\n \"\"\"Initialize with optimal YAML implementation (C-based when available).\"\"\"\n import yaml\n\n self.yaml = yaml\n # Use C-based implementation if available for better performance\n try:\n self.SafeLoader = yaml.CSafeLoader\n self.SafeDumper = yaml.CSafeDumper\n except (AttributeError, ImportError):\n self.SafeLoader = yaml.SafeLoader\n self.SafeDumper = yaml.SafeDumper\n\n @classmethod\n def save(cls, file=\"data.yaml\", data=None, header=\"\"):\n \"\"\"\n Save Python object as YAML file.\n\n Args:\n file (str | Path): Path to save YAML file.\n data (dict | None): Dict or compatible object to save.\n header (str): Optional string to add at file beginning.\n \"\"\"\n instance = cls._get_instance()\n if data is None:\n data = {}\n\n # Create parent directories if needed\n file = Path(file)\n file.parent.mkdir(parents=True, exist_ok=True)\n\n # Convert non-serializable objects to strings\n valid_types = int, float, str, bool, list, tuple, dict, type(None)\n for k, v in data.items():\n if not isinstance(v, valid_types):\n data[k] = str(v)\n\n # Write YAML file\n with open(file, \"w\", errors=\"ignore\", encoding=\"utf-8\") as f:\n if header:\n f.write(header)\n instance.yaml.dump(data, f, sort_keys=False, allow_unicode=True, Dumper=instance.SafeDumper)\n\n @classmethod\n def load(cls, file=\"data.yaml\", append_filename=False):\n \"\"\"\n Load YAML file to Python object with robust error handling.\n\n Args:\n file (str | Path): Path to YAML file.\n append_filename (bool): Whether to add filename to returned dict.\n\n Returns:\n (dict): Loaded YAML content.\n \"\"\"\n instance = cls._get_instance()\n assert str(file).endswith((\".yaml\", \".yml\")), f\"Not a YAML file: {file}\"\n\n # Read file content\n with open(file, errors=\"ignore\", encoding=\"utf-8\") as f:\n s = f.read()\n\n # Try loading YAML with fallback for problematic characters\n try:\n data = instance.yaml.load(s, Loader=instance.SafeLoader) or {}\n except Exception:\n # Remove problematic characters and retry\n s = re.sub(r\"[^\\x09\\x0A\\x0D\\x20-\\x7E\\x85\\xA0-\\uD7FF\\uE000-\\uFFFD\\U00010000-\\U0010ffff]+\", \"\", s)\n data = instance.yaml.load(s, Loader=instance.SafeLoader) or {}\n\n # Check for accidental user-error None strings (should be 'null' in YAML)\n if \"None\" in data.values():\n data = {k: None if v == \"None\" else v for k, v in data.items()}\n\n if append_filename:\n data[\"yaml_file\"] = str(file)\n return data\n\n @classmethod\n def print(cls, yaml_file):\n \"\"\"\n Pretty print YAML file or object to 
console.\n\n Args:\n yaml_file (str | Path | dict): Path to YAML file or dict to print.\n \"\"\"\n instance = cls._get_instance()\n\n # Load file if path provided\n yaml_dict = cls.load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file\n\n # Use -1 for unlimited width in C implementation\n dump = instance.yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=-1, Dumper=instance.SafeDumper)\n\n LOGGER.info(f\"Printing '{colorstr('bold', 'black', yaml_file)}'\\n\\n{dump}\")", "chunk_type": "class", "name": "YAML", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 625, "end_line": 756, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": "YAML utility class for efficient file operations with automatic C-implementation detection.\n\nThis class provides optimized YAML loading and saving operations using PyYAML's fastest available implementation\n(C-based when possible). It implements a singleton pattern with lazy initialization, allowing direct class method\nusage without explicit instantiation. The class handles file path creation, validation, and character encoding\nissues automatically.\n\nThe implementation prioritizes performance through:\n - Automatic C-based loader/dumper selection when available\n - Singleton pattern to reuse the same instance\n - Lazy initialization to defer import costs until needed\n - Fallback mechanisms for handling problematic YAML content\n\nAttributes:\n _instance: Internal singleton instance storage.\n yaml: Reference to the PyYAML module.\n SafeLoader: Best available YAML loader (CSafeLoader if available).\n SafeDumper: Best available YAML dumper (CSafeDumper if available).\n\nExamples:\n >>> data = YAML.load(\"config.yaml\")\n >>> data[\"new_value\"] = 123\n >>> YAML.save(\"updated_config.yaml\", data)\n >>> YAML.print(data)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "class_YAML_8935df41" }, { "content": "DEFAULT_CFG_DICT = YAML.load(DEFAULT_CFG_PATH)", "chunk_type": "variable", "name": "DEFAULT_CFG_DICT", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 760, "end_line": 760, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEFAULT_CFG_DICT_07c66c33" }, { "content": "DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()", "chunk_type": "variable", "name": "DEFAULT_CFG_KEYS", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 761, "end_line": 761, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEFAULT_CFG_KEYS_971a588d" }, { "content": 
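The YAML helper above is used through its classmethods without instantiation. A hedged usage sketch (the path tmp/config.yaml and the config keys are placeholders; assumes ultralytics and PyYAML are installed):

```python
from pathlib import Path

from ultralytics.utils import YAML

cfg = {"epochs": 100, "lr0": 0.01, "names": ["person", "car"], "save_dir": Path("runs/train")}

# save() creates parent directories and stringifies non-serializable values such as Path
YAML.save("tmp/config.yaml", cfg, header="# demo config\n")

# load() returns a dict; append_filename=True adds the source path under 'yaml_file'
data = YAML.load("tmp/config.yaml", append_filename=True)
assert data["epochs"] == 100 and data["yaml_file"].endswith("config.yaml")

YAML.print(data)  # pretty-prints the dict through LOGGER
```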
"DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)", "chunk_type": "variable", "name": "DEFAULT_CFG", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 762, "end_line": 762, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEFAULT_CFG_2b455215" }, { "content": "def read_device_model() -> str:\n \"\"\"\n Read the device model information from the system and cache it for quick access.\n\n Returns:\n (str): Kernel release information.\n \"\"\"\n return platform.release().lower()", "chunk_type": "function", "name": "read_device_model", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 765, "end_line": 772, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Read the device model information from the system and cache it for quick access.\n\nReturns:\n (str): Kernel release information.", "parameters": [], "return_type": "str", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_read_device_model_958dbac8" }, { "content": "def is_ubuntu() -> bool:\n \"\"\"\n Check if the OS is Ubuntu.\n\n Returns:\n (bool): True if OS is Ubuntu, False otherwise.\n \"\"\"\n try:\n with open(\"/etc/os-release\") as f:\n return \"ID=ubuntu\" in f.read()\n except FileNotFoundError:\n return False", "chunk_type": "function", "name": "is_ubuntu", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 775, "end_line": 786, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Check if the OS is Ubuntu.\n\nReturns:\n (bool): True if OS is Ubuntu, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_ubuntu_49c5d0a0" }, { "content": "def is_colab():\n \"\"\"\n Check if the current script is running inside a Google Colab notebook.\n\n Returns:\n (bool): True if running inside a Colab notebook, False otherwise.\n \"\"\"\n return \"COLAB_RELEASE_TAG\" in os.environ or \"COLAB_BACKEND_VERSION\" in os.environ", "chunk_type": "function", "name": 
"is_colab", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 789, "end_line": 796, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": "Check if the current script is running inside a Google Colab notebook.\n\nReturns:\n (bool): True if running inside a Colab notebook, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_colab_625ca7e5" }, { "content": "def is_kaggle():\n \"\"\"\n Check if the current script is running inside a Kaggle kernel.\n\n Returns:\n (bool): True if running inside a Kaggle kernel, False otherwise.\n \"\"\"\n return os.environ.get(\"PWD\") == \"/kaggle/working\" and os.environ.get(\"KAGGLE_URL_BASE\") == \"https://www.kaggle.com\"", "chunk_type": "function", "name": "is_kaggle", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 799, "end_line": 806, "start_col": 0, "end_col": 119, "parent_name": null, "docstring": "Check if the current script is running inside a Kaggle kernel.\n\nReturns:\n (bool): True if running inside a Kaggle kernel, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_kaggle_57045760" }, { "content": "def is_jupyter():\n \"\"\"\n Check if the current script is running inside a Jupyter Notebook.\n\n Returns:\n (bool): True if running inside a Jupyter Notebook, False otherwise.\n\n Notes:\n - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable.\n - \"get_ipython\" in globals() method suffers false positives when IPython package installed manually.\n \"\"\"\n return IS_COLAB or IS_KAGGLE", "chunk_type": "function", "name": "is_jupyter", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 809, "end_line": 820, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Check if the current script is running inside a Jupyter Notebook.\n\nReturns:\n (bool): True if running inside a Jupyter Notebook, False otherwise.\n\nNotes:\n - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not 
reliably detectable.\n - \"get_ipython\" in globals() method suffers false positives when IPython package installed manually.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_jupyter_2cf61eed" }, { "content": "def is_runpod():\n \"\"\"\n Check if the current script is running inside a RunPod container.\n\n Returns:\n (bool): True if running in RunPod, False otherwise.\n \"\"\"\n return \"RUNPOD_POD_ID\" in os.environ", "chunk_type": "function", "name": "is_runpod", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 823, "end_line": 830, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Check if the current script is running inside a RunPod container.\n\nReturns:\n (bool): True if running in RunPod, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_runpod_a92259b9" }, { "content": "def is_docker() -> bool:\n \"\"\"\n Determine if the script is running inside a Docker container.\n\n Returns:\n (bool): True if the script is running inside a Docker container, False otherwise.\n \"\"\"\n try:\n return os.path.exists(\"/.dockerenv\")\n except Exception:\n return False", "chunk_type": "function", "name": "is_docker", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 833, "end_line": 843, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Determine if the script is running inside a Docker container.\n\nReturns:\n (bool): True if the script is running inside a Docker container, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", 
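The environment probes defined in the surrounding chunks (is_colab, is_kaggle, is_jupyter, is_runpod, is_docker) are cheap boolean checks. A sketch of how they might be combined (assumes ultralytics is importable; the dict keys are arbitrary labels):

```python
from ultralytics.utils import is_colab, is_docker, is_jupyter, is_kaggle, is_runpod

# Each helper inspects environment variables or filesystem markers and returns a bool
checks = {
    "colab": is_colab(),      # COLAB_RELEASE_TAG / COLAB_BACKEND_VERSION env vars
    "kaggle": is_kaggle(),    # PWD and KAGGLE_URL_BASE env vars
    "jupyter": is_jupyter(),  # True only on Colab or Kaggle
    "runpod": is_runpod(),    # RUNPOD_POD_ID env var
    "docker": is_docker(),    # /.dockerenv file present
}
print({k: v for k, v in checks.items() if v} or "no managed environment detected")
```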
"importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_docker_dcc1b9ae" }, { "content": "def is_raspberrypi() -> bool:\n \"\"\"\n Determine if the Python environment is running on a Raspberry Pi.\n\n Returns:\n (bool): True if running on a Raspberry Pi, False otherwise.\n \"\"\"\n return \"rpi\" in DEVICE_MODEL", "chunk_type": "function", "name": "is_raspberrypi", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 846, "end_line": 853, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Determine if the Python environment is running on a Raspberry Pi.\n\nReturns:\n (bool): True if running on a Raspberry Pi, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_raspberrypi_2a530ff0" }, { "content": "def is_jetson() -> bool:\n \"\"\"\n Determine if the Python environment is running on an NVIDIA Jetson device.\n\n Returns:\n (bool): True if running on an NVIDIA Jetson device, False otherwise.\n \"\"\"\n return \"tegra\" in DEVICE_MODEL", "chunk_type": "function", "name": "is_jetson", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 856, "end_line": 863, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "Determine if the Python environment is running on an NVIDIA Jetson device.\n\nReturns:\n (bool): True if running on an NVIDIA Jetson device, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_jetson_28a84d17" }, { "content": "def is_online() -> bool:\n \"\"\"\n Check internet connectivity by attempting to connect to a known online host.\n\n Returns:\n (bool): True if connection is successful, False otherwise.\n \"\"\"\n try:\n assert str(os.getenv(\"YOLO_OFFLINE\", \"\")).lower() != \"true\" # check if ENV var YOLO_OFFLINE=\"True\"\n import socket\n\n for dns in (\"1.1.1.1\", \"8.8.8.8\"): # check Cloudflare and Google DNS\n socket.create_connection(address=(dns, 
80), timeout=2.0).close()\n return True\n except Exception:\n return False", "chunk_type": "function", "name": "is_online", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 866, "end_line": 881, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Check internet connectivity by attempting to connect to a known online host.\n\nReturns:\n (bool): True if connection is successful, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_online_9e65e651" }, { "content": "def is_pip_package(filepath: str = __name__) -> bool:\n \"\"\"\n Determine if the file at the given filepath is part of a pip package.\n\n Args:\n filepath (str): The filepath to check.\n\n Returns:\n (bool): True if the file is part of a pip package, False otherwise.\n \"\"\"\n import importlib.util\n\n # Get the spec for the module\n spec = importlib.util.find_spec(filepath)\n\n # Return whether the spec is not None and the origin is not None (indicating it is a package)\n return spec is not None and spec.origin is not None", "chunk_type": "function", "name": "is_pip_package", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 884, "end_line": 900, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Determine if the file at the given filepath is part of a pip package.\n\nArgs:\n filepath (str): The filepath to check.\n\nReturns:\n (bool): True if the file is part of a pip package, False otherwise.", "parameters": [ "filepath: str" ], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_pip_package_fc1f9870" }, { "content": "def is_dir_writeable(dir_path: Union[str, Path]) -> bool:\n \"\"\"\n Check if a directory is writeable.\n\n Args:\n dir_path (str | Path): The path to the directory.\n\n Returns:\n (bool): True if the directory is writeable, False otherwise.\n \"\"\"\n return os.access(str(dir_path), os.W_OK)", "chunk_type": "function", "name": "is_dir_writeable", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 903, "end_line": 913, "start_col": 0, "end_col": 44, "parent_name": 
null, "docstring": "Check if a directory is writeable.\n\nArgs:\n dir_path (str | Path): The path to the directory.\n\nReturns:\n (bool): True if the directory is writeable, False otherwise.", "parameters": [ "dir_path: Union[str, Path]" ], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_dir_writeable_1802308c" }, { "content": "def is_pytest_running():\n \"\"\"\n Determine whether pytest is currently running or not.\n\n Returns:\n (bool): True if pytest is running, False otherwise.\n \"\"\"\n return (\"PYTEST_CURRENT_TEST\" in os.environ) or (\"pytest\" in sys.modules) or (\"pytest\" in Path(ARGV[0]).stem)", "chunk_type": "function", "name": "is_pytest_running", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 916, "end_line": 923, "start_col": 0, "end_col": 113, "parent_name": null, "docstring": "Determine whether pytest is currently running or not.\n\nReturns:\n (bool): True if pytest is running, False otherwise.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_pytest_running_b6b77f74" }, { "content": "def is_github_action_running() -> bool:\n \"\"\"\n Determine if the current environment is a GitHub Actions runner.\n\n Returns:\n (bool): True if the current environment is a GitHub Actions runner, False otherwise.\n \"\"\"\n return \"GITHUB_ACTIONS\" in os.environ and \"GITHUB_WORKFLOW\" in os.environ and \"RUNNER_OS\" in os.environ", "chunk_type": "function", "name": "is_github_action_running", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 926, "end_line": 933, "start_col": 0, "end_col": 107, "parent_name": null, "docstring": "Determine if the current environment is a GitHub Actions runner.\n\nReturns:\n (bool): True if the current environment is a GitHub Actions runner, False otherwise.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", 
"typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_github_action_running_689e73f5" }, { "content": "def get_git_dir():\n \"\"\"\n Determine whether the current file is part of a git repository and if so, return the repository root directory.\n\n Returns:\n (Path | None): Git root directory if found or None if not found.\n \"\"\"\n for d in Path(__file__).parents:\n if (d / \".git\").is_dir():\n return d", "chunk_type": "function", "name": "get_git_dir", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 936, "end_line": 945, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Determine whether the current file is part of a git repository and if so, return the repository root directory.\n\nReturns:\n (Path | None): Git root directory if found or None if not found.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_get_git_dir_e37bf49c" }, { "content": "def is_git_dir():\n \"\"\"\n Determine whether the current file is part of a git repository.\n\n Returns:\n (bool): True if current file is part of a git repository.\n \"\"\"\n return GIT_DIR is not None", "chunk_type": "function", "name": "is_git_dir", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 948, "end_line": 955, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Determine whether the current file is part of a git repository.\n\nReturns:\n (bool): True if current file is part of a git repository.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_is_git_dir_b10e45b8" }, { "content": "def get_git_origin_url():\n \"\"\"\n Retrieve the origin URL of a git repository.\n\n Returns:\n (str 
| None): The origin URL of the git repository or None if not git directory.\n \"\"\"\n if IS_GIT_DIR:\n try:\n origin = subprocess.check_output([\"git\", \"config\", \"--get\", \"remote.origin.url\"])\n return origin.decode().strip()\n except subprocess.CalledProcessError:\n return None", "chunk_type": "function", "name": "get_git_origin_url", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 958, "end_line": 970, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Retrieve the origin URL of a git repository.\n\nReturns:\n (str | None): The origin URL of the git repository or None if not git directory.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_get_git_origin_url_75508ee5" }, { "content": "def get_git_branch():\n \"\"\"\n Return the current git branch name. If not in a git repository, return None.\n\n Returns:\n (str | None): The current git branch name or None if not a git directory.\n \"\"\"\n if IS_GIT_DIR:\n try:\n origin = subprocess.check_output([\"git\", \"rev-parse\", \"--abbrev-ref\", \"HEAD\"])\n return origin.decode().strip()\n except subprocess.CalledProcessError:\n return None", "chunk_type": "function", "name": "get_git_branch", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 973, "end_line": 985, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Return the current git branch name. 
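The git helpers above (get_git_dir, is_git_dir, get_git_origin_url, get_git_branch) all return None outside a checkout, so callers can branch on that. An illustrative sketch, assuming ultralytics is importable:

```python
from ultralytics.utils import get_git_branch, get_git_dir, get_git_origin_url

root = get_git_dir()  # Path of the directory containing .git, or None outside a clone
if root is None:
    print("not running from a git checkout (e.g. pip-installed)")
else:
    print("repo root:", root)
    print("origin:", get_git_origin_url())  # None if the git command fails or no git dir
    print("branch:", get_git_branch())      # e.g. 'main', or None outside a git dir
```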
If not in a git repository, return None.\n\nReturns:\n (str | None): The current git branch name or None if not a git directory.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_get_git_branch_63f076d6" }, { "content": "def get_default_args(func):\n \"\"\"\n Return a dictionary of default arguments for a function.\n\n Args:\n func (callable): The function to inspect.\n\n Returns:\n (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.\n \"\"\"\n signature = inspect.signature(func)\n return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}", "chunk_type": "function", "name": "get_default_args", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 988, "end_line": 999, "start_col": 0, "end_col": 110, "parent_name": null, "docstring": "Return a dictionary of default arguments for a function.\n\nArgs:\n func (callable): The function to inspect.\n\nReturns:\n (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.", "parameters": [ "func" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_get_default_args_22819343" }, { "content": "def get_ubuntu_version():\n \"\"\"\n Retrieve the Ubuntu version if the OS is Ubuntu.\n\n Returns:\n (str): Ubuntu version or None if not an Ubuntu OS.\n \"\"\"\n if is_ubuntu():\n try:\n with open(\"/etc/os-release\") as f:\n return re.search(r'VERSION_ID=\"(\\d+\\.\\d+)\"', f.read())[1]\n except (FileNotFoundError, AttributeError):\n return None", "chunk_type": "function", "name": "get_ubuntu_version", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1002, "end_line": 1014, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Retrieve the Ubuntu version if the OS is Ubuntu.\n\nReturns:\n (str): Ubuntu version or None if not an Ubuntu OS.", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", 
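get_default_args above is a thin wrapper over inspect.signature. A quick sketch with a hypothetical train() function:

```python
from ultralytics.utils import get_default_args

def train(model, epochs=100, imgsz=640, device=None):
    """Toy function used only to illustrate default-argument introspection."""
    return model

print(get_default_args(train))
# {'epochs': 100, 'imgsz': 640, 'device': None} -- 'model' has no default, so it is omitted
```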
"platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_get_ubuntu_version_11ddee4f" }, { "content": "def get_user_config_dir(sub_dir=\"Ultralytics\"):\n \"\"\"\n Return the appropriate config directory based on the environment operating system.\n\n Args:\n sub_dir (str): The name of the subdirectory to create.\n\n Returns:\n (Path): The path to the user config directory.\n \"\"\"\n if WINDOWS:\n path = Path.home() / \"AppData\" / \"Roaming\" / sub_dir\n elif MACOS: # macOS\n path = Path.home() / \"Library\" / \"Application Support\" / sub_dir\n elif LINUX:\n path = Path.home() / \".config\" / sub_dir\n else:\n raise ValueError(f\"Unsupported operating system: {platform.system()}\")\n\n # GCP and AWS lambda fix, only /tmp is writeable\n if not is_dir_writeable(path.parent):\n LOGGER.warning(\n f\"user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD.\"\n \"Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.\"\n )\n path = Path(\"/tmp\") / sub_dir if is_dir_writeable(\"/tmp\") else Path().cwd() / sub_dir\n\n # Create the subdirectory if it does not exist\n path.mkdir(parents=True, exist_ok=True)\n\n return path", "chunk_type": "function", "name": "get_user_config_dir", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1017, "end_line": 1047, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Return the appropriate config directory based on the environment operating system.\n\nArgs:\n sub_dir (str): The name of the subdirectory to create.\n\nReturns:\n (Path): The path to the user config directory.", "parameters": [ "sub_dir" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_get_user_config_dir_f324c308" }, { "content": "DEVICE_MODEL = read_device_model() # is_jetson() and is_raspberrypi() depend on this constant", "chunk_type": "variable", "name": "DEVICE_MODEL", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1051, "end_line": 1051, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DEVICE_MODEL_c1957c1a" }, { "content": "ONLINE = is_online()", "chunk_type": 
"variable", "name": "ONLINE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1052, "end_line": 1052, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ONLINE_bb1eb81e" }, { "content": "IS_COLAB = is_colab()", "chunk_type": "variable", "name": "IS_COLAB", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1053, "end_line": 1053, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_COLAB_fc125f59" }, { "content": "IS_KAGGLE = is_kaggle()", "chunk_type": "variable", "name": "IS_KAGGLE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1054, "end_line": 1054, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_KAGGLE_9bd55599" }, { "content": "IS_DOCKER = is_docker()", "chunk_type": "variable", "name": "IS_DOCKER", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1055, "end_line": 1055, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_DOCKER_2f339b0c" }, { "content": "IS_JETSON = is_jetson()", "chunk_type": "variable", "name": "IS_JETSON", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1056, "end_line": 1056, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_JETSON_5d927d6c" }, { "content": "IS_JUPYTER = is_jupyter()", "chunk_type": "variable", "name": "IS_JUPYTER", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1057, "end_line": 1057, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_JUPYTER_563ad653" }, { "content": "IS_PIP_PACKAGE = is_pip_package()", "chunk_type": "variable", "name": "IS_PIP_PACKAGE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1058, "end_line": 1058, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_PIP_PACKAGE_4deb7236" }, { "content": "IS_RASPBERRYPI = is_raspberrypi()", "chunk_type": "variable", "name": "IS_RASPBERRYPI", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1059, "end_line": 1059, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_RASPBERRYPI_1c4fce70" }, { "content": "GIT_DIR = get_git_dir()", "chunk_type": "variable", "name": "GIT_DIR", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1060, "end_line": 1060, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, 
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_GIT_DIR_041c7857" }, { "content": "IS_GIT_DIR = is_git_dir()", "chunk_type": "variable", "name": "IS_GIT_DIR", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1061, "end_line": 1061, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_IS_GIT_DIR_53796298" }, { "content": "USER_CONFIG_DIR = Path(os.getenv(\"YOLO_CONFIG_DIR\") or get_user_config_dir()) # Ultralytics settings dir", "chunk_type": "variable", "name": "USER_CONFIG_DIR", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1062, "end_line": 1062, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_USER_CONFIG_DIR_fe8daf7c" }, { "content": "SETTINGS_FILE = USER_CONFIG_DIR / \"settings.json\"", "chunk_type": "variable", "name": "SETTINGS_FILE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1063, "end_line": 1063, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SETTINGS_FILE_0070b1d4" }, { "content": "def colorstr(*input):\n r\"\"\"\n Color a string based on the provided color and style arguments using ANSI escape codes.\n\n This function can be called in two ways:\n - colorstr('color', 'style', 'your string')\n - colorstr('your string')\n\n In the second form, 'blue' and 'bold' will be applied by default.\n\n Args:\n *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,\n and the last string is the one to be colored.\n\n Returns:\n (str): The input string wrapped with ANSI escape codes for the specified color and style.\n\n Notes:\n Supported Colors and Styles:\n - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'\n - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',\n 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'\n - Misc: 'end', 'bold', 'underline'\n\n Examples:\n >>> colorstr(\"blue\", \"bold\", \"hello world\")\n >>> \"\\033[34m\\033[1mhello world\\033[0m\"\n\n References:\n https://en.wikipedia.org/wiki/ANSI_escape_code\n \"\"\"\n *args, string = input if len(input) > 1 else (\"blue\", \"bold\", input[0]) # color arguments, string\n colors = {\n \"black\": \"\\033[30m\", # basic colors\n \"red\": \"\\033[31m\",\n \"green\": \"\\033[32m\",\n \"yellow\": \"\\033[33m\",\n \"blue\": \"\\033[34m\",\n \"magenta\": \"\\033[35m\",\n \"cyan\": \"\\033[36m\",\n \"white\": \"\\033[37m\",\n \"bright_black\": \"\\033[90m\", # bright colors\n \"bright_red\": \"\\033[91m\",\n \"bright_green\": \"\\033[92m\",\n \"bright_yellow\": \"\\033[93m\",\n \"bright_blue\": \"\\033[94m\",\n \"bright_magenta\": \"\\033[95m\",\n \"bright_cyan\": \"\\033[96m\",\n \"bright_white\": \"\\033[97m\",\n \"end\": \"\\033[0m\", # misc\n \"bold\": \"\\033[1m\",\n \"underline\": \"\\033[4m\",\n }\n return \"\".join(colors[x] for x in args) + f\"{string}\" + colors[\"end\"]", "chunk_type": "function", "name": "colorstr", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1066, "end_line": 
1119, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": "Color a string based on the provided color and style arguments using ANSI escape codes.\n\nThis function can be called in two ways:\n - colorstr('color', 'style', 'your string')\n - colorstr('your string')\n\nIn the second form, 'blue' and 'bold' will be applied by default.\n\nArgs:\n *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,\n and the last string is the one to be colored.\n\nReturns:\n (str): The input string wrapped with ANSI escape codes for the specified color and style.\n\nNotes:\n Supported Colors and Styles:\n - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'\n - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',\n 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'\n - Misc: 'end', 'bold', 'underline'\n\nExamples:\n >>> colorstr(\"blue\", \"bold\", \"hello world\")\n >>> \"\\033[34m\\033[1mhello world\\033[0m\"\n\nReferences:\n https://en.wikipedia.org/wiki/ANSI_escape_code", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_colorstr_64f5b536" }, { "content": "def remove_colorstr(input_string):\n \"\"\"\n Remove ANSI escape codes from a string, effectively un-coloring it.\n\n Args:\n input_string (str): The string to remove color and style from.\n\n Returns:\n (str): A new string with all ANSI escape codes removed.\n\n Examples:\n >>> remove_colorstr(colorstr(\"blue\", \"bold\", \"hello world\"))\n >>> \"hello world\"\n \"\"\"\n ansi_escape = re.compile(r\"\\x1B\\[[0-9;]*[A-Za-z]\")\n return ansi_escape.sub(\"\", input_string)", "chunk_type": "function", "name": "remove_colorstr", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1122, "end_line": 1137, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Remove ANSI escape codes from a string, effectively un-coloring it.\n\nArgs:\n input_string (str): The string to remove color and style from.\n\nReturns:\n (str): A new string with all ANSI escape codes removed.\n\nExamples:\n >>> remove_colorstr(colorstr(\"blue\", \"bold\", \"hello world\"))\n >>> \"hello world\"", "parameters": [ "input_string" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", 
"sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_remove_colorstr_c31acf20" }, { "content": "class TryExcept(contextlib.ContextDecorator):\n \"\"\"\n Ultralytics TryExcept class for handling exceptions gracefully.\n\n This class can be used as a decorator or context manager to catch exceptions and optionally print warning messages.\n It allows code to continue execution even when exceptions occur, which is useful for non-critical operations.\n\n Attributes:\n msg (str): Optional message to display when an exception occurs.\n verbose (bool): Whether to print the exception message.\n\n Examples:\n As a decorator:\n >>> @TryExcept(msg=\"Error occurred in func\", verbose=True)\n >>> def func():\n >>> # Function logic here\n >>> pass\n\n As a context manager:\n >>> with TryExcept(msg=\"Error occurred in block\", verbose=True):\n >>> # Code block here\n >>> pass\n \"\"\"\n\n def __init__(self, msg=\"\", verbose=True):\n \"\"\"Initialize TryExcept class with optional message and verbosity settings.\"\"\"\n self.msg = msg\n self.verbose = verbose\n\n def __enter__(self):\n \"\"\"Execute when entering TryExcept context, initialize instance.\"\"\"\n pass\n\n def __exit__(self, exc_type, value, traceback):\n \"\"\"Define behavior when exiting a 'with' block, print error message if necessary.\"\"\"\n if self.verbose and value:\n LOGGER.warning(f\"{self.msg}{': ' if self.msg else ''}{value}\")\n return True", "chunk_type": "class", "name": "TryExcept", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1140, "end_line": 1177, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Ultralytics TryExcept class for handling exceptions gracefully.\n\nThis class can be used as a decorator or context manager to catch exceptions and optionally print warning messages.\nIt allows code to continue execution even when exceptions occur, which is useful for non-critical operations.\n\nAttributes:\n msg (str): Optional message to display when an exception occurs.\n verbose (bool): Whether to print the exception message.\n\nExamples:\n As a decorator:\n >>> @TryExcept(msg=\"Error occurred in func\", verbose=True)\n >>> def func():\n >>> # Function logic here\n >>> pass\n\n As a context manager:\n >>> with TryExcept(msg=\"Error occurred in block\", verbose=True):\n >>> # Code block here\n >>> pass", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io", "contextlib.ContextDecorator" ], "chunk_id": "class_TryExcept_fe08fe13" }, { "content": "class Retry(contextlib.ContextDecorator):\n \"\"\"\n Retry class for function execution with exponential backoff.\n\n This decorator can be used to retry a function on exceptions, up to a specified number of 
times with an\n exponentially increasing delay between retries. It's useful for handling transient failures in network\n operations or other unreliable processes.\n\n Attributes:\n times (int): Maximum number of retry attempts.\n delay (int): Initial delay between retries in seconds.\n\n Examples:\n Example usage as a decorator:\n >>> @Retry(times=3, delay=2)\n >>> def test_func():\n >>> # Replace with function logic that may raise exceptions\n >>> return True\n \"\"\"\n\n def __init__(self, times=3, delay=2):\n \"\"\"Initialize Retry class with specified number of retries and delay.\"\"\"\n self.times = times\n self.delay = delay\n self._attempts = 0\n\n def __call__(self, func):\n \"\"\"Decorator implementation for Retry with exponential backoff.\"\"\"\n\n def wrapped_func(*args, **kwargs):\n \"\"\"Apply retries to the decorated function or method.\"\"\"\n self._attempts = 0\n while self._attempts < self.times:\n try:\n return func(*args, **kwargs)\n except Exception as e:\n self._attempts += 1\n LOGGER.warning(f\"Retry {self._attempts}/{self.times} failed: {e}\")\n if self._attempts >= self.times:\n raise e\n time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay\n\n return wrapped_func", "chunk_type": "class", "name": "Retry", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1180, "end_line": 1222, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Retry class for function execution with exponential backoff.\n\nThis decorator can be used to retry a function on exceptions, up to a specified number of times with an\nexponentially increasing delay between retries. It's useful for handling transient failures in network\noperations or other unreliable processes.\n\nAttributes:\n times (int): Maximum number of retry attempts.\n delay (int): Initial delay between retries in seconds.\n\nExamples:\n Example usage as a decorator:\n >>> @Retry(times=3, delay=2)\n >>> def test_func():\n >>> # Replace with function logic that may raise exceptions\n >>> return True", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io", "contextlib.ContextDecorator" ], "chunk_id": "class_Retry_abfaf7c4" }, { "content": "def threaded(func):\n \"\"\"\n Multi-thread a target function by default and return the thread or function result.\n\n This decorator provides flexible execution of the target function, either in a separate thread or synchronously.\n By default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument\n which is removed from kwargs before calling the function.\n\n Args:\n func (callable): The function to be potentially executed in a separate thread.\n\n Returns:\n (callable): A wrapper function that either returns a daemon thread or the direct function result.\n\n Examples:\n >>> @threaded\n ... 
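The `Retry` chunk above documents a decorator that re-runs a failing callable up to `times` attempts, sleeping `delay * 2**attempt` seconds between tries. A minimal usage sketch, assuming `ultralytics` is installed; the function name and its fail-twice-then-succeed pattern are illustrative only:

```python
from ultralytics.utils import Retry

calls = {"n": 0}  # track how many times the flaky function has run

@Retry(times=3, delay=1)
def flaky_download():
    """Illustrative function that fails twice, then succeeds on the third call."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient network failure")
    return "ok"

print(flaky_download())  # logs two retry warnings, sleeps ~2 s then ~4 s, prints 'ok'
```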
def process_data(data):\n ... return data\n >>>\n >>> thread = process_data(my_data) # Runs in background thread\n >>> result = process_data(my_data, threaded=False) # Runs synchronously, returns function result\n \"\"\"\n\n def wrapper(*args, **kwargs):\n \"\"\"Multi-thread a given function based on 'threaded' kwarg and return the thread or function result.\"\"\"\n if kwargs.pop(\"threaded\", True): # run in thread\n thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)\n thread.start()\n return thread\n else:\n return func(*args, **kwargs)\n\n return wrapper", "chunk_type": "function", "name": "threaded", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1225, "end_line": 1257, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Multi-thread a target function by default and return the thread or function result.\n\nThis decorator provides flexible execution of the target function, either in a separate thread or synchronously.\nBy default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument\nwhich is removed from kwargs before calling the function.\n\nArgs:\n func (callable): The function to be potentially executed in a separate thread.\n\nReturns:\n (callable): A wrapper function that either returns a daemon thread or the direct function result.\n\nExamples:\n >>> @threaded\n ... def process_data(data):\n ... return data\n >>>\n >>> thread = process_data(my_data) # Runs in background thread\n >>> result = process_data(my_data, threaded=False) # Runs synchronously, returns function result", "parameters": [ "func" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_threaded_82043875" }, { "content": "def set_sentry():\n \"\"\"\n Initialize the Sentry SDK for error tracking and reporting.\n\n Only used if sentry_sdk package is installed and sync=True in settings. 
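Because the `threaded` decorator returns the started daemon `Thread` rather than the function's result, callers that need the return value must pass `threaded=False`, as the chunk's docstring states. A short sketch of both call styles; the worker function is hypothetical:

```python
from ultralytics.utils import threaded

@threaded
def sum_of_squares(n):
    """Hypothetical CPU-bound worker used for illustration."""
    return sum(i * i for i in range(n))

thread = sum_of_squares(1_000_000)              # returns a started daemon threading.Thread
thread.join()                                   # wait for it; the return value is discarded
result = sum_of_squares(1_000, threaded=False)  # run synchronously to get the result
print(result)  # 332833500
```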
Run 'yolo settings' to see and update\n settings.\n\n Conditions required to send errors (ALL conditions must be met or no errors will be reported):\n - sentry_sdk package is installed\n - sync=True in YOLO settings\n - pytest is not running\n - running in a pip package installation\n - running in a non-git directory\n - running with rank -1 or 0\n - online environment\n - CLI used to run package (checked with 'yolo' as the name of the main CLI command)\n \"\"\"\n if (\n not SETTINGS[\"sync\"]\n or RANK not in {-1, 0}\n or Path(ARGV[0]).name != \"yolo\"\n or TESTS_RUNNING\n or not ONLINE\n or not IS_PIP_PACKAGE\n or IS_GIT_DIR\n ):\n return\n # If sentry_sdk package is not installed then return and do not use Sentry\n try:\n import sentry_sdk # noqa\n except ImportError:\n return\n\n def before_send(event, hint):\n \"\"\"\n Modify the event before sending it to Sentry based on specific exception types and messages.\n\n Args:\n event (dict): The event dictionary containing information about the error.\n hint (dict): A dictionary containing additional information about the error.\n\n Returns:\n (dict | None): The modified event or None if the event should not be sent to Sentry.\n \"\"\"\n if \"exc_info\" in hint:\n exc_type, exc_value, _ = hint[\"exc_info\"]\n if exc_type in {KeyboardInterrupt, FileNotFoundError} or \"out of memory\" in str(exc_value):\n return None # do not send event\n\n event[\"tags\"] = {\n \"sys_argv\": ARGV[0],\n \"sys_argv_name\": Path(ARGV[0]).name,\n \"install\": \"git\" if IS_GIT_DIR else \"pip\" if IS_PIP_PACKAGE else \"other\",\n \"os\": ENVIRONMENT,\n }\n return event\n\n sentry_sdk.init(\n dsn=\"https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016\",\n debug=False,\n auto_enabling_integrations=False,\n traces_sample_rate=1.0,\n release=__version__,\n environment=\"runpod\" if is_runpod() else \"production\",\n before_send=before_send,\n ignore_errors=[KeyboardInterrupt, FileNotFoundError],\n )\n sentry_sdk.set_user({\"id\": SETTINGS[\"uuid\"]}) # SHA-256 anonymized UUID hash", "chunk_type": "function", "name": "set_sentry", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1260, "end_line": 1327, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": "Initialize the Sentry SDK for error tracking and reporting.\n\nOnly used if sentry_sdk package is installed and sync=True in settings. 
Run 'yolo settings' to see and update\nsettings.\n\nConditions required to send errors (ALL conditions must be met or no errors will be reported):\n - sentry_sdk package is installed\n - sync=True in YOLO settings\n - pytest is not running\n - running in a pip package installation\n - running in a non-git directory\n - running with rank -1 or 0\n - online environment\n - CLI used to run package (checked with 'yolo' as the name of the main CLI command)", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_set_sentry_9d4803f9" }, { "content": "class JSONDict(dict):\n \"\"\"\n A dictionary-like class that provides JSON persistence for its contents.\n\n This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are\n modified. It ensures thread-safe operations using a lock and handles JSON serialization of Path objects.\n\n Attributes:\n file_path (Path): The path to the JSON file used for persistence.\n lock (threading.Lock): A lock object to ensure thread-safe operations.\n\n Methods:\n _load: Load the data from the JSON file into the dictionary.\n _save: Save the current state of the dictionary to the JSON file.\n __setitem__: Store a key-value pair and persist it to disk.\n __delitem__: Remove an item and update the persistent storage.\n update: Update the dictionary and persist changes.\n clear: Clear all entries and update the persistent storage.\n\n Examples:\n >>> json_dict = JSONDict(\"data.json\")\n >>> json_dict[\"key\"] = \"value\"\n >>> print(json_dict[\"key\"])\n value\n >>> del json_dict[\"key\"]\n >>> json_dict.update({\"new_key\": \"new_value\"})\n >>> json_dict.clear()\n \"\"\"\n\n def __init__(self, file_path: Union[str, Path] = \"data.json\"):\n \"\"\"Initialize a JSONDict object with a specified file path for JSON persistence.\"\"\"\n super().__init__()\n self.file_path = Path(file_path)\n self.lock = Lock()\n self._load()\n\n def _load(self):\n \"\"\"Load the data from the JSON file into the dictionary.\"\"\"\n try:\n if self.file_path.exists():\n with open(self.file_path) as f:\n self.update(json.load(f))\n except json.JSONDecodeError:\n LOGGER.warning(f\"Error decoding JSON from {self.file_path}. 
Starting with an empty dictionary.\")\n except Exception as e:\n LOGGER.error(f\"Error reading from {self.file_path}: {e}\")\n\n def _save(self):\n \"\"\"Save the current state of the dictionary to the JSON file.\"\"\"\n try:\n self.file_path.parent.mkdir(parents=True, exist_ok=True)\n with open(self.file_path, \"w\", encoding=\"utf-8\") as f:\n json.dump(dict(self), f, indent=2, default=self._json_default)\n except Exception as e:\n LOGGER.error(f\"Error writing to {self.file_path}: {e}\")\n\n @staticmethod\n def _json_default(obj):\n \"\"\"Handle JSON serialization of Path objects.\"\"\"\n if isinstance(obj, Path):\n return str(obj)\n raise TypeError(f\"Object of type {type(obj).__name__} is not JSON serializable\")\n\n def __setitem__(self, key, value):\n \"\"\"Store a key-value pair and persist to disk.\"\"\"\n with self.lock:\n super().__setitem__(key, value)\n self._save()\n\n def __delitem__(self, key):\n \"\"\"Remove an item and update the persistent storage.\"\"\"\n with self.lock:\n super().__delitem__(key)\n self._save()\n\n def __str__(self):\n \"\"\"Return a pretty-printed JSON string representation of the dictionary.\"\"\"\n contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default)\n return f'JSONDict(\"{self.file_path}\"):\\n{contents}'\n\n def update(self, *args, **kwargs):\n \"\"\"Update the dictionary and persist changes.\"\"\"\n with self.lock:\n super().update(*args, **kwargs)\n self._save()\n\n def clear(self):\n \"\"\"Clear all entries and update the persistent storage.\"\"\"\n with self.lock:\n super().clear()\n self._save()", "chunk_type": "class", "name": "JSONDict", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1330, "end_line": 1420, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "A dictionary-like class that provides JSON persistence for its contents.\n\nThis class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are\nmodified. 
It ensures thread-safe operations using a lock and handles JSON serialization of Path objects.\n\nAttributes:\n file_path (Path): The path to the JSON file used for persistence.\n lock (threading.Lock): A lock object to ensure thread-safe operations.\n\nMethods:\n _load: Load the data from the JSON file into the dictionary.\n _save: Save the current state of the dictionary to the JSON file.\n __setitem__: Store a key-value pair and persist it to disk.\n __delitem__: Remove an item and update the persistent storage.\n update: Update the dictionary and persist changes.\n clear: Clear all entries and update the persistent storage.\n\nExamples:\n >>> json_dict = JSONDict(\"data.json\")\n >>> json_dict[\"key\"] = \"value\"\n >>> print(json_dict[\"key\"])\n value\n >>> del json_dict[\"key\"]\n >>> json_dict.update({\"new_key\": \"new_value\"})\n >>> json_dict.clear()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io", "dict" ], "chunk_id": "class_JSONDict_c1d7c6d5" }, { "content": "class SettingsManager(JSONDict):\n \"\"\"\n SettingsManager class for managing and persisting Ultralytics settings.\n\n This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default\n values. It validates settings on initialization and provides methods to update or reset settings. 
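To make the persistence behaviour of `JSONDict` concrete, the sketch below writes, reloads, and clears a throwaway file. The file name `demo_cache.json` is purely illustrative, and `Path` values are converted to strings by the documented `_json_default` handler:

```python
from pathlib import Path
from ultralytics.utils import JSONDict

path = Path("demo_cache.json")      # illustrative file name
d = JSONDict(path)
d["weights"] = Path("yolo11n.pt")   # Path values are serialized as strings on save
d.update({"epochs": 3})

d2 = JSONDict(path)                 # a fresh instance reloads the persisted contents
print(d2["epochs"])                 # 3
print(d2)                           # pretty-printed JSON representation

del d2["weights"]                   # deletions are persisted too
d2.clear()
path.unlink(missing_ok=True)        # remove the demo file
```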
The settings\n include directories for datasets, weights, and runs, as well as various integration flags.\n\n Attributes:\n file (Path): The path to the JSON file used for persistence.\n version (str): The version of the settings schema.\n defaults (dict): A dictionary containing default settings.\n help_msg (str): A help message for users on how to view and update settings.\n\n Methods:\n _validate_settings: Validate the current settings and reset if necessary.\n update: Update settings, validating keys and types.\n reset: Reset the settings to default and save them.\n\n Examples:\n Initialize and update settings:\n >>> settings = SettingsManager()\n >>> settings.update(runs_dir=\"/new/runs/dir\")\n >>> print(settings[\"runs_dir\"])\n /new/runs/dir\n \"\"\"\n\n def __init__(self, file=SETTINGS_FILE, version=\"0.0.6\"):\n \"\"\"Initialize the SettingsManager with default settings and load user settings.\"\"\"\n import hashlib\n import uuid\n\n from ultralytics.utils.torch_utils import torch_distributed_zero_first\n\n root = GIT_DIR or Path()\n datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve()\n\n self.file = Path(file)\n self.version = version\n self.defaults = {\n \"settings_version\": version, # Settings schema version\n \"datasets_dir\": str(datasets_root / \"datasets\"), # Datasets directory\n \"weights_dir\": str(root / \"weights\"), # Model weights directory\n \"runs_dir\": str(root / \"runs\"), # Experiment runs directory\n \"uuid\": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash\n \"sync\": True, # Enable synchronization\n \"api_key\": \"\", # Ultralytics API Key\n \"openai_api_key\": \"\", # OpenAI API Key\n \"clearml\": True, # ClearML integration\n \"comet\": True, # Comet integration\n \"dvc\": True, # DVC integration\n \"hub\": True, # Ultralytics HUB integration\n \"mlflow\": True, # MLflow integration\n \"neptune\": True, # Neptune integration\n \"raytune\": True, # Ray Tune integration\n \"tensorboard\": False, # TensorBoard logging\n \"wandb\": False, # Weights & Biases logging\n \"vscode_msg\": True, # VSCode message\n \"openvino_msg\": True, # OpenVINO export on Intel CPU message\n }\n\n self.help_msg = (\n f\"\\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'\"\n \"\\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. \"\n \"For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.\"\n )\n\n with torch_distributed_zero_first(LOCAL_RANK):\n super().__init__(self.file)\n\n if not self.file.exists() or not self: # Check if file doesn't exist or is empty\n LOGGER.info(f\"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}\")\n self.reset()\n\n self._validate_settings()\n\n def _validate_settings(self):\n \"\"\"Validate the current settings and reset if necessary.\"\"\"\n correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())\n correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())\n correct_version = self.get(\"settings_version\", \"\") == self.version\n\n if not (correct_keys and correct_types and correct_version):\n LOGGER.warning(\n \"Ultralytics settings reset to default values. This may be due to a possible problem \"\n f\"with your settings or a recent ultralytics package update. 
{self.help_msg}\"\n )\n self.reset()\n\n if self.get(\"datasets_dir\") == self.get(\"runs_dir\"):\n LOGGER.warning(\n f\"Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' \"\n f\"must be different than 'runs_dir: {self.get('runs_dir')}'. \"\n f\"Please change one to avoid possible issues during training. {self.help_msg}\"\n )\n\n def __setitem__(self, key, value):\n \"\"\"Update one key: value pair.\"\"\"\n self.update({key: value})\n\n def update(self, *args, **kwargs):\n \"\"\"Update settings, validating keys and types.\"\"\"\n for arg in args:\n if isinstance(arg, dict):\n kwargs.update(arg)\n for k, v in kwargs.items():\n if k not in self.defaults:\n raise KeyError(f\"No Ultralytics setting '{k}'. {self.help_msg}\")\n t = type(self.defaults[k])\n if not isinstance(v, t):\n raise TypeError(\n f\"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. {self.help_msg}\"\n )\n super().update(*args, **kwargs)\n\n def reset(self):\n \"\"\"Reset the settings to default and save them.\"\"\"\n self.clear()\n self.update(self.defaults)", "chunk_type": "class", "name": "SettingsManager", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1423, "end_line": 1541, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": "SettingsManager class for managing and persisting Ultralytics settings.\n\nThis class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default\nvalues. It validates settings on initialization and provides methods to update or reset settings. The settings\ninclude directories for datasets, weights, and runs, as well as various integration flags.\n\nAttributes:\n file (Path): The path to the JSON file used for persistence.\n version (str): The version of the settings schema.\n defaults (dict): A dictionary containing default settings.\n help_msg (str): A help message for users on how to view and update settings.\n\nMethods:\n _validate_settings: Validate the current settings and reset if necessary.\n update: Update settings, validating keys and types.\n reset: Reset the settings to default and save them.\n\nExamples:\n Initialize and update settings:\n >>> settings = SettingsManager()\n >>> settings.update(runs_dir=\"/new/runs/dir\")\n >>> print(settings[\"runs_dir\"])\n /new/runs/dir", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io", "JSONDict" ], "chunk_id": "class_SettingsManager_4b018472" }, { "content": "def deprecation_warn(arg, new_arg=None):\n \"\"\"Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.\"\"\"\n msg = f\"'{arg}' is deprecated and will be removed in in the future.\"\n if new_arg is not None:\n msg += f\" Use '{new_arg}' instead.\"\n LOGGER.warning(msg)", "chunk_type": "function", "name": 
"deprecation_warn", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1544, "end_line": 1549, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.", "parameters": [ "arg", "new_arg" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_deprecation_warn_6d27243b" }, { "content": "def clean_url(url):\n \"\"\"Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.\"\"\"\n url = Path(url).as_posix().replace(\":/\", \"://\") # Pathlib turns :// -> :/, as_posix() for Windows\n return unquote(url).split(\"?\", 1)[0] # '%2F' to '/', split https://url.com/file.txt?auth", "chunk_type": "function", "name": "clean_url", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1552, "end_line": 1555, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.", "parameters": [ "url" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_clean_url_c74078e5" }, { "content": "def url2file(url):\n \"\"\"Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.\"\"\"\n return Path(clean_url(url)).name", "chunk_type": "function", "name": "url2file", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1558, "end_line": 1560, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "Convert URL to filename, i.e. 
https://url.com/file.txt?auth -> file.txt.", "parameters": [ "url" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_url2file_fafbdcda" }, { "content": "def vscode_msg(ext=\"ultralytics.ultralytics-snippets\") -> str:\n \"\"\"Display a message to install Ultralytics-Snippets for VS Code if not already installed.\"\"\"\n path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / \".vscode/extensions\"\n obs_file = path / \".obsolete\" # file tracks uninstalled extensions, while source directory remains\n installed = any(path.glob(f\"{ext}*\")) and ext not in (obs_file.read_text(\"utf-8\") if obs_file.exists() else \"\")\n url = \"https://docs.ultralytics.com/integrations/vscode\"\n return \"\" if installed else f\"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}\"", "chunk_type": "function", "name": "vscode_msg", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1563, "end_line": 1569, "start_col": 0, "end_col": 105, "parent_name": null, "docstring": "Display a message to install Ultralytics-Snippets for VS Code if not already installed.", "parameters": [ "ext" ], "return_type": "str", "decorators": [], "complexity_score": 1, "dependencies": [ "contextlib", "importlib.metadata", "inspect", "json", "logging", "os", "platform", "re", "subprocess", "sys", "threading", "time", "warnings", "pathlib.Path", "threading.Lock", "types.SimpleNamespace", "typing.Union", "urllib.parse.unquote", "cv2", "numpy", "torch", "tqdm", "ultralytics.__version__", "ultralytics.utils.patches.imread", "ultralytics.utils.patches.imshow", "ultralytics.utils.patches.imwrite", "ultralytics.utils.patches.torch_save", "tqdm.rich", "importlib.util", "pandas", "sqlite3", "functools.wraps", "yaml", "socket", "sentry_sdk", "hashlib", "uuid", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "matplotlib.pyplot", "io" ], "chunk_id": "function_vscode_msg_a4abd200" }, { "content": "PREFIX = colorstr(\"Ultralytics: \")", "chunk_type": "variable", "name": "PREFIX", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1575, "end_line": 1575, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PREFIX_edda2740" }, { "content": "SETTINGS = SettingsManager() # initialize settings", "chunk_type": "variable", "name": "SETTINGS", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1576, "end_line": 1576, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_SETTINGS_50e0f203" }, { "content": "PERSISTENT_CACHE = 
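A quick illustration of the two URL helpers documented above; the release URL is only an example and the query string is a placeholder token:

```python
from ultralytics.utils import clean_url, url2file

url = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt?token=abc123"
print(clean_url(url))  # https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt
print(url2file(url))   # yolo11n.pt
```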
JSONDict(USER_CONFIG_DIR / \"persistent_cache.json\") # initialize persistent cache", "chunk_type": "variable", "name": "PERSISTENT_CACHE", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1577, "end_line": 1577, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_PERSISTENT_CACHE_48cfe187" }, { "content": "DATASETS_DIR = Path(SETTINGS[\"datasets_dir\"]) # global datasets directory", "chunk_type": "variable", "name": "DATASETS_DIR", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1578, "end_line": 1578, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_DATASETS_DIR_c43da5f2" }, { "content": "WEIGHTS_DIR = Path(SETTINGS[\"weights_dir\"]) # global weights directory", "chunk_type": "variable", "name": "WEIGHTS_DIR", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1579, "end_line": 1579, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_WEIGHTS_DIR_f0a0ecc2" }, { "content": "RUNS_DIR = Path(SETTINGS[\"runs_dir\"]) # global runs directory", "chunk_type": "variable", "name": "RUNS_DIR", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1580, "end_line": 1580, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_RUNS_DIR_fe066ddb" }, { "content": "ENVIRONMENT = (\n \"Colab\"\n if IS_COLAB\n else \"Kaggle\"\n if IS_KAGGLE\n else \"Jupyter\"\n if IS_JUPYTER\n else \"Docker\"\n if IS_DOCKER\n else platform.system()\n)", "chunk_type": "variable", "name": "ENVIRONMENT", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1581, "end_line": 1591, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_ENVIRONMENT_ff00e629" }, { "content": "TESTS_RUNNING = is_pytest_running() or is_github_action_running()", "chunk_type": "variable", "name": "TESTS_RUNNING", "file_path": "ultralytics\\ultralytics\\utils\\__init__.py", "start_line": 1592, "end_line": 1592, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_TESTS_RUNNING_2f8be4b0" }, { "content": "import concurrent.futures", "chunk_type": "import", "name": "concurrent.futures", "file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_concurrent.futures_b27c7ccd" }, { "content": "import statistics", "chunk_type": "import", "name": "statistics", "file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": null, 
"parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_statistics_2b7efb0d" }, { "content": "import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_cacbb5ec" }, { "content": "from typing import List, Optional, Tuple", "chunk_type": "import", "name": "List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple_ceea995b" }, { "content": "import requests", "chunk_type": "import", "name": "requests", "file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_requests_c171e7cd" }, { "content": "class GCPRegions:\n \"\"\"\n A class for managing and analyzing Google Cloud Platform (GCP) regions.\n\n This class provides functionality to initialize, categorize, and analyze GCP regions based on their\n geographical location, tier classification, and network latency.\n\n Attributes:\n regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.\n\n Methods:\n tier1: Returns a list of tier 1 GCP regions.\n tier2: Returns a list of tier 2 GCP regions.\n lowest_latency: Determines the GCP region(s) with the lowest network latency.\n\n Examples:\n >>> from ultralytics.hub.google import GCPRegions\n >>> regions = GCPRegions()\n >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)\n >>> print(f\"Lowest latency region: {lowest_latency_region[0][0]}\")\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details.\"\"\"\n self.regions = {\n \"asia-east1\": (1, \"Taiwan\", \"China\"),\n \"asia-east2\": (2, \"Hong Kong\", \"China\"),\n \"asia-northeast1\": (1, \"Tokyo\", \"Japan\"),\n \"asia-northeast2\": (1, \"Osaka\", \"Japan\"),\n \"asia-northeast3\": (2, \"Seoul\", \"South Korea\"),\n \"asia-south1\": (2, \"Mumbai\", \"India\"),\n \"asia-south2\": (2, \"Delhi\", \"India\"),\n \"asia-southeast1\": (2, \"Jurong West\", \"Singapore\"),\n \"asia-southeast2\": (2, \"Jakarta\", \"Indonesia\"),\n \"australia-southeast1\": (2, \"Sydney\", \"Australia\"),\n \"australia-southeast2\": (2, \"Melbourne\", \"Australia\"),\n \"europe-central2\": (2, \"Warsaw\", \"Poland\"),\n \"europe-north1\": (1, \"Hamina\", \"Finland\"),\n \"europe-southwest1\": (1, \"Madrid\", \"Spain\"),\n \"europe-west1\": (1, \"St. 
Ghislain\", \"Belgium\"),\n \"europe-west10\": (2, \"Berlin\", \"Germany\"),\n \"europe-west12\": (2, \"Turin\", \"Italy\"),\n \"europe-west2\": (2, \"London\", \"United Kingdom\"),\n \"europe-west3\": (2, \"Frankfurt\", \"Germany\"),\n \"europe-west4\": (1, \"Eemshaven\", \"Netherlands\"),\n \"europe-west6\": (2, \"Zurich\", \"Switzerland\"),\n \"europe-west8\": (1, \"Milan\", \"Italy\"),\n \"europe-west9\": (1, \"Paris\", \"France\"),\n \"me-central1\": (2, \"Doha\", \"Qatar\"),\n \"me-west1\": (1, \"Tel Aviv\", \"Israel\"),\n \"northamerica-northeast1\": (2, \"Montreal\", \"Canada\"),\n \"northamerica-northeast2\": (2, \"Toronto\", \"Canada\"),\n \"southamerica-east1\": (2, \"São Paulo\", \"Brazil\"),\n \"southamerica-west1\": (2, \"Santiago\", \"Chile\"),\n \"us-central1\": (1, \"Iowa\", \"United States\"),\n \"us-east1\": (1, \"South Carolina\", \"United States\"),\n \"us-east4\": (1, \"Northern Virginia\", \"United States\"),\n \"us-east5\": (1, \"Columbus\", \"United States\"),\n \"us-south1\": (1, \"Dallas\", \"United States\"),\n \"us-west1\": (1, \"Oregon\", \"United States\"),\n \"us-west2\": (2, \"Los Angeles\", \"United States\"),\n \"us-west3\": (2, \"Salt Lake City\", \"United States\"),\n \"us-west4\": (2, \"Las Vegas\", \"United States\"),\n }\n\n def tier1(self) -> List[str]:\n \"\"\"Return a list of GCP regions classified as tier 1 based on predefined criteria.\"\"\"\n return [region for region, info in self.regions.items() if info[0] == 1]\n\n def tier2(self) -> List[str]:\n \"\"\"Return a list of GCP regions classified as tier 2 based on predefined criteria.\"\"\"\n return [region for region, info in self.regions.items() if info[0] == 2]\n\n @staticmethod\n def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]:\n \"\"\"\n Ping a specified GCP region and measure network latency statistics.\n\n Args:\n region (str): The GCP region identifier to ping (e.g., 'us-central1').\n attempts (int, optional): Number of ping attempts to make for calculating statistics.\n\n Returns:\n region (str): The GCP region identifier that was pinged.\n mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.\n std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.\n min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.\n max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.\n\n Examples:\n >>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region(\"us-central1\", attempts=3)\n >>> print(f\"Region {region} has mean latency: {mean:.2f}ms\")\n \"\"\"\n url = f\"https://{region}-docker.pkg.dev\"\n latencies = []\n for _ in range(attempts):\n try:\n start_time = time.time()\n _ = requests.head(url, timeout=5)\n latency = (time.time() - start_time) * 1000 # Convert latency to milliseconds\n if latency != float(\"inf\"):\n latencies.append(latency)\n except requests.RequestException:\n pass\n if not latencies:\n return region, float(\"inf\"), float(\"inf\"), float(\"inf\"), float(\"inf\")\n\n std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0\n return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)\n\n def lowest_latency(\n self,\n top: int = 1,\n verbose: bool = False,\n tier: Optional[int] = None,\n attempts: int = 1,\n ) -> List[Tuple[str, float, float, float, float]]:\n \"\"\"\n Determine the GCP regions with the lowest latency based on ping tests.\n\n Args:\n top (int, 
optional): Number of top regions to return.\n verbose (bool, optional): If True, prints detailed latency information for all tested regions.\n tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.\n attempts (int, optional): Number of ping attempts per region.\n\n Returns:\n (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and\n latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).\n\n Examples:\n >>> regions = GCPRegions()\n >>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)\n >>> print(results[0][0]) # Print the name of the lowest latency region\n \"\"\"\n if verbose:\n print(f\"Testing GCP regions for latency (with {attempts} {'retry' if attempts == 1 else 'attempts'})...\")\n\n regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys())\n with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:\n results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))\n\n sorted_results = sorted(results, key=lambda x: x[1])\n\n if verbose:\n print(f\"{'Region':<25} {'Location':<35} {'Tier':<5} Latency (ms)\")\n for region, mean, std, min_, max_ in sorted_results:\n tier, city, country = self.regions[region]\n location = f\"{city}, {country}\"\n if mean == float(\"inf\"):\n print(f\"{region:<25} {location:<35} {tier:<5} Timeout\")\n else:\n print(f\"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})\")\n print(f\"\\nLowest latency region{'s' if top > 1 else ''}:\")\n for region, mean, std, min_, max_ in sorted_results[:top]:\n tier, city, country = self.regions[region]\n location = f\"{city}, {country}\"\n print(f\"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))\")\n\n return sorted_results[:top]", "chunk_type": "class", "name": "GCPRegions", "file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py", "start_line": 11, "end_line": 170, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "A class for managing and analyzing Google Cloud Platform (GCP) regions.\n\nThis class provides functionality to initialize, categorize, and analyze GCP regions based on their\ngeographical location, tier classification, and network latency.\n\nAttributes:\n regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.\n\nMethods:\n tier1: Returns a list of tier 1 GCP regions.\n tier2: Returns a list of tier 2 GCP regions.\n lowest_latency: Determines the GCP region(s) with the lowest network latency.\n\nExamples:\n >>> from ultralytics.hub.google import GCPRegions\n >>> regions = GCPRegions()\n >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)\n >>> print(f\"Lowest latency region: {lowest_latency_region[0][0]}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "concurrent.futures", "statistics", "time", "typing.List", "typing.Optional", "typing.Tuple", "requests" ], "chunk_id": "class_GCPRegions_792880b7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, 
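Beyond the docstring example, `lowest_latency` can be restricted to a tier and given multiple ping attempts. The sketch below assumes network access to the regional `*-docker.pkg.dev` endpoints and simply prints whatever latencies happen to be measured:

```python
from ultralytics.hub.google import GCPRegions

regions = GCPRegions()
print(f"{len(regions.tier1())} tier-1 and {len(regions.tier2())} tier-2 regions")

# Ping each tier-1 region twice and keep the three fastest (results vary by location/network).
for region, mean, std, min_, max_ in regions.lowest_latency(top=3, tier=1, attempts=2):
    print(f"{region}: {mean:.0f} ± {std:.0f} ms ({min_:.0f}-{max_:.0f})")
```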
"chunk_id": "import_Path_88442790" }, { "content": "from typing import Any, Dict, List, Optional", "chunk_type": "import", "name": "Any, Dict, List, Optional", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional_ffa21a54" }, { "content": "from ultralytics.engine.model import Model", "chunk_type": "import", "name": "Model", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Model_932838d8" }, { "content": "from .predict import FastSAMPredictor", "chunk_type": "import", "name": "FastSAMPredictor", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FastSAMPredictor_26ba58f9" }, { "content": "from .val import FastSAMValidator", "chunk_type": "import", "name": "FastSAMValidator", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FastSAMValidator_ab2f29b8" }, { "content": "class FastSAM(Model):\n \"\"\"\n FastSAM model interface for segment anything tasks.\n\n This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything\n Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.\n\n Attributes:\n model (str): Path to the pre-trained FastSAM model file.\n task (str): The task type, set to \"segment\" for FastSAM models.\n\n Methods:\n predict: Perform segmentation prediction on image or video source with optional prompts.\n task_map: Returns mapping of segment task to predictor and validator classes.\n\n Examples:\n Initialize FastSAM model and run prediction\n >>> from ultralytics import FastSAM\n >>> model = FastSAM(\"FastSAM-x.pt\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\n Run prediction with bounding box prompts\n >>> results = model.predict(\"image.jpg\", bboxes=[[100, 100, 200, 200]])\n \"\"\"\n\n def __init__(self, model: str = \"FastSAM-x.pt\"):\n \"\"\"Initialize the FastSAM model with the specified pre-trained weights.\"\"\"\n if str(model) == \"FastSAM.pt\":\n model = \"FastSAM-x.pt\"\n assert Path(model).suffix not in {\".yaml\", \".yml\"}, \"FastSAM models only support pre-trained models.\"\n super().__init__(model=model, task=\"segment\")\n\n def predict(\n self,\n source,\n stream: bool = False,\n bboxes: Optional[List] = None,\n points: Optional[List] = None,\n labels: Optional[List] = None,\n texts: Optional[List] = None,\n **kwargs: Any,\n ):\n \"\"\"\n Perform segmentation prediction on image or video source.\n\n Supports prompted segmentation with bounding boxes, points, labels, and texts. 
The method packages these\n prompts and passes them to the parent class predict method for processing.\n\n Args:\n source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,\n or numpy array.\n stream (bool): Whether to enable real-time streaming mode for video inputs.\n bboxes (List, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].\n points (List, optional): Point coordinates for prompted segmentation in format [[x, y]].\n labels (List, optional): Class labels for prompted segmentation.\n texts (List, optional): Text prompts for segmentation guidance.\n **kwargs (Any): Additional keyword arguments passed to the predictor.\n\n Returns:\n (List): List of Results objects containing the prediction results.\n \"\"\"\n prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)\n return super().predict(source, stream, prompts=prompts, **kwargs)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Returns a dictionary mapping segment task to corresponding predictor and validator classes.\"\"\"\n return {\"segment\": {\"predictor\": FastSAMPredictor, \"validator\": FastSAMValidator}}", "chunk_type": "class", "name": "FastSAM", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py", "start_line": 12, "end_line": 79, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": "FastSAM model interface for segment anything tasks.\n\nThis class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything\nModel) implementation, allowing for efficient and accurate image segmentation with optional prompting support.\n\nAttributes:\n model (str): Path to the pre-trained FastSAM model file.\n task (str): The task type, set to \"segment\" for FastSAM models.\n\nMethods:\n predict: Perform segmentation prediction on image or video source with optional prompts.\n task_map: Returns mapping of segment task to predictor and validator classes.\n\nExamples:\n Initialize FastSAM model and run prediction\n >>> from ultralytics import FastSAM\n >>> model = FastSAM(\"FastSAM-x.pt\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\n Run prediction with bounding box prompts\n >>> results = model.predict(\"image.jpg\", bboxes=[[100, 100, 200, 200]])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "ultralytics.engine.model.Model", "predict.FastSAMPredictor", "val.FastSAMValidator", "Model" ], "chunk_id": "class_FastSAM_eabbd80a" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_36591e18" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_40761ca2" }, { "content": "from ultralytics.models.yolo.segment import SegmentationPredictor", 
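The `FastSAM` chunk's docstring shows bbox prompting; `predict` also accepts points, labels, and text prompts, which are packaged and forwarded to the predictor. A usage sketch in which the weights file, image path, and prompt coordinates are illustrative:

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # defaults to "FastSAM-x.pt" if no model is given

source = "ultralytics/assets/bus.jpg"                               # illustrative image path
everything = model.predict(source)                                  # segment everything, no prompts
by_box = model.predict(source, bboxes=[[100, 100, 400, 500]])
by_point = model.predict(source, points=[[300, 400]], labels=[1])   # 1 = foreground point
by_text = model.predict(source, texts="a photo of a bus")           # needs the CLIP dependency
```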
"chunk_type": "import", "name": "SegmentationPredictor", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationPredictor_1b76dc07" }, { "content": "from ultralytics.utils import DEFAULT_CFG, checks", "chunk_type": "import", "name": "DEFAULT_CFG, checks", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, checks_7e1856ec" }, { "content": "from ultralytics.utils.metrics import box_iou", "chunk_type": "import", "name": "box_iou", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_box_iou_3c40557b" }, { "content": "from ultralytics.utils.ops import scale_masks", "chunk_type": "import", "name": "scale_masks", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_scale_masks_73ead197" }, { "content": "from .utils import adjust_bboxes_to_image_border", "chunk_type": "import", "name": "adjust_bboxes_to_image_border", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_adjust_bboxes_to_image_border_8833e2ef" }, { "content": "class FastSAMPredictor(SegmentationPredictor):\n \"\"\"\n FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.\n\n This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It\n adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for\n single-class segmentation.\n\n Attributes:\n prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).\n device (torch.device): Device on which model and tensors are processed.\n clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.\n clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.\n\n Methods:\n postprocess: Apply postprocessing to FastSAM predictions and handle prompts.\n prompt: Perform image segmentation inference based on various prompt types.\n set_prompts: Set prompts to be used during inference.\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the FastSAMPredictor with configuration and callbacks.\n\n This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. 
The predictor\n extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression\n optimized for single-class segmentation.\n\n Args:\n cfg (dict): Configuration for the predictor.\n overrides (dict, optional): Configuration overrides.\n _callbacks (list, optional): List of callback functions.\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.prompts = {}\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Apply postprocessing to FastSAM predictions and handle prompts.\n\n Args:\n preds (List[torch.Tensor]): Raw predictions from the model.\n img (torch.Tensor): Input image tensor that was fed to the model.\n orig_imgs (List[np.ndarray]): Original images before preprocessing.\n\n Returns:\n (List[Results]): Processed results with prompts applied.\n \"\"\"\n bboxes = self.prompts.pop(\"bboxes\", None)\n points = self.prompts.pop(\"points\", None)\n labels = self.prompts.pop(\"labels\", None)\n texts = self.prompts.pop(\"texts\", None)\n results = super().postprocess(preds, img, orig_imgs)\n for result in results:\n full_box = torch.tensor(\n [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32\n )\n boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)\n idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()\n if idx.numel() != 0:\n result.boxes.xyxy[idx] = full_box\n\n return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)\n\n def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):\n \"\"\"\n Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.\n\n Args:\n results (Results | List[Results]): Original inference results from FastSAM models without any prompts.\n bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.\n points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.\n labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 
1 = foreground, 0 = background.\n texts (str | List[str], optional): Textual prompts, a list containing string objects.\n\n Returns:\n (List[Results]): Output results filtered and determined by the provided prompts.\n \"\"\"\n if bboxes is None and points is None and texts is None:\n return results\n prompt_results = []\n if not isinstance(results, list):\n results = [results]\n for result in results:\n if len(result) == 0:\n prompt_results.append(result)\n continue\n masks = result.masks.data\n if masks.shape[1:] != result.orig_shape:\n masks = scale_masks(masks[None], result.orig_shape)[0]\n # bboxes prompt\n idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)\n if bboxes is not None:\n bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)\n bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes\n bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])\n mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])\n full_mask_areas = torch.sum(masks, dim=(1, 2))\n\n union = bbox_areas[:, None] + full_mask_areas - mask_areas\n idx[torch.argmax(mask_areas / union, dim=1)] = True\n if points is not None:\n points = torch.as_tensor(points, dtype=torch.int32, device=self.device)\n points = points[None] if points.ndim == 1 else points\n if labels is None:\n labels = torch.ones(points.shape[0])\n labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)\n assert len(labels) == len(points), (\n f\"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}\"\n )\n point_idx = (\n torch.ones(len(result), dtype=torch.bool, device=self.device)\n if labels.sum() == 0 # all negative points\n else torch.zeros(len(result), dtype=torch.bool, device=self.device)\n )\n for point, label in zip(points, labels):\n point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)\n idx |= point_idx\n if texts is not None:\n if isinstance(texts, str):\n texts = [texts]\n crop_ims, filter_idx = [], []\n for i, b in enumerate(result.boxes.xyxy.tolist()):\n x1, y1, x2, y2 = (int(x) for x in b)\n if masks[i].sum() <= 100:\n filter_idx.append(i)\n continue\n crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))\n similarity = self._clip_inference(crop_ims, texts)\n text_idx = torch.argmax(similarity, dim=-1) # (M, )\n if len(filter_idx):\n text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)\n idx[text_idx] = True\n\n prompt_results.append(result[idx])\n\n return prompt_results\n\n def _clip_inference(self, images, texts):\n \"\"\"\n Perform CLIP inference to calculate similarity between images and text prompts.\n\n Args:\n images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.\n texts (List[str]): List of prompt texts, each should be a string object.\n\n Returns:\n (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).\n \"\"\"\n try:\n import clip\n except ImportError:\n checks.check_requirements(\"git+https://github.com/ultralytics/CLIP.git\")\n import clip\n if (not hasattr(self, \"clip_model\")) or (not hasattr(self, \"clip_preprocess\")):\n self.clip_model, self.clip_preprocess = clip.load(\"ViT-B/32\", device=self.device)\n images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])\n tokenized_text = clip.tokenize(texts).to(self.device)\n image_features = self.clip_model.encode_image(images)\n text_features = 
self.clip_model.encode_text(tokenized_text)\n image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512)\n text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512)\n return (image_features * text_features[:, None]).sum(-1) # (M, N)\n\n def set_prompts(self, prompts):\n \"\"\"Set prompts to be used during inference.\"\"\"\n self.prompts = prompts", "chunk_type": "class", "name": "FastSAMPredictor", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py", "start_line": 14, "end_line": 180, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.\n\nThis class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It\nadjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for\nsingle-class segmentation.\n\nAttributes:\n prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).\n device (torch.device): Device on which model and tensors are processed.\n clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.\n clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.\n\nMethods:\n postprocess: Apply postprocessing to FastSAM predictions and handle prompts.\n prompt: Perform image segmentation inference based on various prompt types.\n set_prompts: Set prompts to be used during inference.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "PIL.Image", "ultralytics.models.yolo.segment.SegmentationPredictor", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.checks", "ultralytics.utils.metrics.box_iou", "ultralytics.utils.ops.scale_masks", "utils.adjust_bboxes_to_image_border", "clip", "clip", "SegmentationPredictor" ], "chunk_id": "class_FastSAMPredictor_16f795d4" }, { "content": "def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):\n \"\"\"\n Adjust bounding boxes to stick to image border if they are within a certain threshold.\n\n Args:\n boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.\n image_shape (tuple): Image dimensions as (height, width).\n threshold (int): Pixel threshold for considering a box close to the border.\n\n Returns:\n (torch.Tensor): Adjusted bounding boxes with shape (N, 4).\n \"\"\"\n # Image dimensions\n h, w = image_shape\n\n # Adjust boxes that are close to image borders\n boxes[boxes[:, 0] < threshold, 0] = 0 # x1\n boxes[boxes[:, 1] < threshold, 1] = 0 # y1\n boxes[boxes[:, 2] > w - threshold, 2] = w # x2\n boxes[boxes[:, 3] > h - threshold, 3] = h # y2\n return boxes", "chunk_type": "function", "name": "adjust_bboxes_to_image_border", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\utils.py", "start_line": 4, "end_line": 24, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Adjust bounding boxes to stick to image border if they are within a certain threshold.\n\nArgs:\n boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.\n image_shape (tuple): Image dimensions as (height, width).\n threshold (int): Pixel threshold for considering a box close to the border.\n\nReturns:\n (torch.Tensor): Adjusted bounding boxes with shape (N, 4).", "parameters": [ "boxes", "image_shape", "threshold" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [], "chunk_id": 
"function_adjust_bboxes_to_image_border_775308ef" }, { "content": "from ultralytics.models.yolo.segment import SegmentationValidator", "chunk_type": "import", "name": "SegmentationValidator", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationValidator_3adc7638" }, { "content": "class FastSAMValidator(SegmentationValidator):\n \"\"\"\n Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.\n\n Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class\n sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled\n to avoid errors during validation.\n\n Attributes:\n dataloader (torch.utils.data.DataLoader): The data loader object used for validation.\n save_dir (Path): The directory where validation results will be saved.\n args (SimpleNamespace): Additional arguments for customization of the validation process.\n _callbacks (list): List of callback functions to be invoked during validation.\n metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.\n\n Methods:\n __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):\n \"\"\"\n Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.\n save_dir (Path, optional): Directory to save results.\n args (SimpleNamespace, optional): Configuration for the validator.\n _callbacks (list, optional): List of callback functions to be invoked during validation.\n\n Notes:\n Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.args.task = \"segment\"\n self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors", "chunk_type": "class", "name": "FastSAMValidator", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\val.py", "start_line": 6, "end_line": 40, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.\n\nExtends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class\nsets the task to 'segment' and uses the SegmentMetrics for evaluation. 
Additionally, plotting features are disabled\nto avoid errors during validation.\n\nAttributes:\n dataloader (torch.utils.data.DataLoader): The data loader object used for validation.\n save_dir (Path): The directory where validation results will be saved.\n args (SimpleNamespace): Additional arguments for customization of the validation process.\n _callbacks (list): List of callback functions to be invoked during validation.\n metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.\n\nMethods:\n __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ultralytics.models.yolo.segment.SegmentationValidator", "SegmentationValidator" ], "chunk_id": "class_FastSAMValidator_f79634ac" }, { "content": "from .model import FastSAM", "chunk_type": "import", "name": "FastSAM", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FastSAM_596b8d4c" }, { "content": "from .predict import FastSAMPredictor", "chunk_type": "import", "name": "FastSAMPredictor", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FastSAMPredictor_a37ddd48" }, { "content": "from .val import FastSAMValidator", "chunk_type": "import", "name": "FastSAMValidator", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FastSAMValidator_dce83088" }, { "content": "__all__ = \"FastSAMPredictor\", \"FastSAM\", \"FastSAMValidator\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___caf38e02" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_63a1b45c" }, { "content": "from typing import Any, Dict", "chunk_type": "import", "name": "Any, Dict", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict_867260ee" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 6, "end_line": 6, "start_col": 
0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_d26b0c34" }, { "content": "from ultralytics.engine.model import Model", "chunk_type": "import", "name": "Model", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Model_16a03cde" }, { "content": "from ultralytics.utils import DEFAULT_CFG_DICT", "chunk_type": "import", "name": "DEFAULT_CFG_DICT", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG_DICT_d978629b" }, { "content": "from ultralytics.utils.downloads import attempt_download_asset", "chunk_type": "import", "name": "attempt_download_asset", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_attempt_download_asset_cfb43222" }, { "content": "from ultralytics.utils.patches import torch_load", "chunk_type": "import", "name": "torch_load", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_load_3c9fda17" }, { "content": "from ultralytics.utils.torch_utils import model_info", "chunk_type": "import", "name": "model_info", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_model_info_c8b686cd" }, { "content": "from .predict import NASPredictor", "chunk_type": "import", "name": "NASPredictor", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NASPredictor_30757246" }, { "content": "from .val import NASValidator", "chunk_type": "import", "name": "NASValidator", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NASValidator_60ddb312" }, { "content": "class NAS(Model):\n \"\"\"\n YOLO-NAS model for object detection.\n\n This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.\n It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS 
models.\n\n Attributes:\n model (torch.nn.Module): The loaded YOLO-NAS model.\n task (str): The task type for the model, defaults to 'detect'.\n predictor (NASPredictor): The predictor instance for making predictions.\n validator (NASValidator): The validator instance for model validation.\n\n Methods:\n info: Log model information and return model details.\n\n Examples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\n Notes:\n YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.\n \"\"\"\n\n def __init__(self, model: str = \"yolo_nas_s.pt\") -> None:\n \"\"\"Initialize the NAS model with the provided or default model.\"\"\"\n assert Path(model).suffix not in {\".yaml\", \".yml\"}, \"YOLO-NAS models only support pre-trained models.\"\n super().__init__(model, task=\"detect\")\n\n def _load(self, weights: str, task=None) -> None:\n \"\"\"\n Load an existing NAS model weights or create a new NAS model with pretrained weights.\n\n Args:\n weights (str): Path to the model weights file or model name.\n task (str, optional): Task type for the model.\n \"\"\"\n import super_gradients\n\n suffix = Path(weights).suffix\n if suffix == \".pt\":\n self.model = torch_load(attempt_download_asset(weights))\n elif suffix == \"\":\n self.model = super_gradients.training.models.get(weights, pretrained_weights=\"coco\")\n\n # Override the forward method to ignore additional arguments\n def new_forward(x, *args, **kwargs):\n \"\"\"Ignore additional __call__ arguments.\"\"\"\n return self.model._original_forward(x)\n\n self.model._original_forward = self.model.forward\n self.model.forward = new_forward\n\n # Standardize model attributes for compatibility\n self.model.fuse = lambda verbose=True: self.model\n self.model.stride = torch.tensor([32])\n self.model.names = dict(enumerate(self.model._class_names))\n self.model.is_fused = lambda: False # for info()\n self.model.yaml = {} # for info()\n self.model.pt_path = weights # for export()\n self.model.task = \"detect\" # for export()\n self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # for export()\n self.model.eval()\n\n def info(self, detailed: bool = False, verbose: bool = True) -> Dict[str, Any]:\n \"\"\"\n Log model information.\n\n Args:\n detailed (bool): Show detailed information about model.\n verbose (bool): Controls verbosity.\n\n Returns:\n (Dict[str, Any]): Model information dictionary.\n \"\"\"\n return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Return a dictionary mapping tasks to respective predictor and validator classes.\"\"\"\n return {\"detect\": {\"predictor\": NASPredictor, \"validator\": NASValidator}}", "chunk_type": "class", "name": "NAS", "file_path": "ultralytics\\ultralytics\\models\\nas\\model.py", "start_line": 18, "end_line": 99, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "YOLO-NAS model for object detection.\n\nThis class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.\nIt is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.\n\nAttributes:\n model (torch.nn.Module): The loaded YOLO-NAS model.\n task (str): The task type for the model, defaults to 'detect'.\n predictor (NASPredictor): The predictor instance for making predictions.\n validator (NASValidator): The validator 
instance for model validation.\n\nMethods:\n info: Log model information and return model details.\n\nExamples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\nNotes:\n YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "torch", "ultralytics.engine.model.Model", "ultralytics.utils.DEFAULT_CFG_DICT", "ultralytics.utils.downloads.attempt_download_asset", "ultralytics.utils.patches.torch_load", "ultralytics.utils.torch_utils.model_info", "predict.NASPredictor", "val.NASValidator", "super_gradients", "Model" ], "chunk_id": "class_NAS_ce1430e4" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_cb8909b1" }, { "content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor", "chunk_type": "import", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionPredictor_8fa7d599" }, { "content": "from ultralytics.utils import ops", "chunk_type": "import", "name": "ops", "file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ops_7a866dcc" }, { "content": "class NASPredictor(DetectionPredictor):\n \"\"\"\n Ultralytics YOLO NAS Predictor for object detection.\n\n This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the\n raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and\n scaling the bounding boxes to fit the original image dimensions.\n\n Attributes:\n args (Namespace): Namespace containing various configurations for post-processing including confidence\n threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.\n model (torch.nn.Module): The YOLO NAS model used for inference.\n batch (list): Batch of inputs for processing.\n\n Examples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> predictor = model.predictor\n\n Assume that raw_preds, img, orig_imgs are available\n >>> results = predictor.postprocess(raw_preds, img, orig_imgs)\n\n Notes:\n Typically, this class is not instantiated directly. 
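A short, hedged usage sketch for the NAS wrapper above: the weight name and image URL are placeholders, and loading by a bare model name (no suffix) additionally requires the optional super-gradients package.

```python
# Hedged sketch: running a YOLO-NAS model through the NAS wrapper above.
# "yolo_nas_s.pt" and the image URL are placeholders; a .pt file is loaded
# with torch_load(), while a bare name would go through super_gradients.
from ultralytics import NAS

model = NAS("yolo_nas_s.pt")
model.info(detailed=False, verbose=True)  # model_info() summary at imgsz=640

results = model.predict("https://ultralytics.com/images/bus.jpg", conf=0.25)
for r in results:
    print(r.boxes.xyxy.shape, r.boxes.cls.unique())
```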
It is used internally within the NAS class.\n \"\"\"\n\n def postprocess(self, preds_in, img, orig_imgs):\n \"\"\"\n Postprocess NAS model predictions to generate final detection results.\n\n This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies\n post-processing operations to generate the final detection results compatible with Ultralytics\n result visualization and analysis tools.\n\n Args:\n preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.\n img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).\n orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling\n coordinates back to original dimensions.\n\n Returns:\n (list): List of Results objects containing the processed predictions for each image in the batch.\n\n Examples:\n >>> predictor = NAS(\"yolo_nas_s\").predictor\n >>> results = predictor.postprocess(raw_preds, img, orig_imgs)\n \"\"\"\n boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding boxes from xyxy to xywh format\n preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with class scores\n return super().postprocess(preds, img, orig_imgs)", "chunk_type": "class", "name": "NASPredictor", "file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py", "start_line": 9, "end_line": 58, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": "Ultralytics YOLO NAS Predictor for object detection.\n\nThis class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the\nraw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and\nscaling the bounding boxes to fit the original image dimensions.\n\nAttributes:\n args (Namespace): Namespace containing various configurations for post-processing including confidence\n threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.\n model (torch.nn.Module): The YOLO NAS model used for inference.\n batch (list): Batch of inputs for processing.\n\nExamples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> predictor = model.predictor\n\n Assume that raw_preds, img, orig_imgs are available\n >>> results = predictor.postprocess(raw_preds, img, orig_imgs)\n\nNotes:\n Typically, this class is not instantiated directly. 
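The conversion at the heart of NASPredictor.postprocess() above is easiest to see with dummy tensors; the shapes below (8400 candidate boxes, 80 classes, batch of 1) are hypothetical stand-ins for real YOLO-NAS output.

```python
# Dummy-tensor walk-through of the box/score reshaping in
# NASPredictor.postprocess(): xyxy boxes are converted to xywh, concatenated
# with class scores, and permuted into the (batch, 4 + nc, num_boxes) layout
# expected by the parent DetectionPredictor's NMS step.
import torch

from ultralytics.utils import ops

raw_boxes = torch.rand(1, 8400, 4) * 640  # hypothetical xyxy boxes
raw_scores = torch.rand(1, 8400, 80)      # hypothetical per-class scores

boxes = ops.xyxy2xywh(raw_boxes)                             # (1, 8400, 4)
preds = torch.cat((boxes, raw_scores), -1).permute(0, 2, 1)  # (1, 84, 8400)
print(preds.shape)  # torch.Size([1, 84, 8400])
```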
It is used internally within the NAS class.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "ultralytics.models.yolo.detect.predict.DetectionPredictor", "ultralytics.utils.ops", "DetectionPredictor" ], "chunk_id": "class_NASPredictor_68ba2ed5" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\nas\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_e167895d" }, { "content": "from ultralytics.models.yolo.detect import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\nas\\val.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_3c78c9ab" }, { "content": "from ultralytics.utils import ops", "chunk_type": "import", "name": "ops", "file_path": "ultralytics\\ultralytics\\models\\nas\\val.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ops_1d91d0fa" }, { "content": "__all__ = [\"NASValidator\"]", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\nas\\val.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___72cf61b7" }, { "content": "class NASValidator(DetectionValidator):\n \"\"\"\n Ultralytics YOLO NAS Validator for object detection.\n\n Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions\n generated by YOLO NAS models. 
It performs non-maximum suppression to remove overlapping and low-confidence boxes,\n ultimately producing the final detections.\n\n Attributes:\n args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU\n thresholds.\n lb (torch.Tensor): Optional tensor for multilabel NMS.\n\n Examples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> validator = model.validator\n >>> # Assumes that raw_preds are available\n >>> final_preds = validator.postprocess(raw_preds)\n\n Notes:\n This class is generally not instantiated directly but is used internally within the NAS class.\n \"\"\"\n\n def postprocess(self, preds_in):\n \"\"\"Apply Non-maximum suppression to prediction outputs.\"\"\"\n boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding box format from xyxy to xywh\n preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with scores and permute\n return super().postprocess(preds)", "chunk_type": "class", "name": "NASValidator", "file_path": "ultralytics\\ultralytics\\models\\nas\\val.py", "start_line": 11, "end_line": 39, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Ultralytics YOLO NAS Validator for object detection.\n\nExtends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions\ngenerated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,\nultimately producing the final detections.\n\nAttributes:\n args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU\n thresholds.\n lb (torch.Tensor): Optional tensor for multilabel NMS.\n\nExamples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> validator = model.validator\n >>> # Assumes that raw_preds are available\n >>> final_preds = validator.postprocess(raw_preds)\n\nNotes:\n This class is generally not instantiated directly but is used internally within the NAS class.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.utils.ops", "DetectionValidator" ], "chunk_id": "class_NASValidator_c323d399" }, { "content": "from .model import NAS", "chunk_type": "import", "name": "NAS", "file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NAS_f3103ce8" }, { "content": "from .predict import NASPredictor", "chunk_type": "import", "name": "NASPredictor", "file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NASPredictor_e55f4165" }, { "content": "from .val import NASValidator", "chunk_type": "import", "name": "NASValidator", "file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_NASValidator_38a1f670" }, { 
"content": "__all__ = \"NASPredictor\", \"NASValidator\", \"NAS\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___080d624a" }, { "content": "from ultralytics.engine.model import Model", "chunk_type": "import", "name": "Model", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Model_70dd2088" }, { "content": "from ultralytics.nn.tasks import RTDETRDetectionModel", "chunk_type": "import", "name": "RTDETRDetectionModel", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRDetectionModel_df62a2a5" }, { "content": "from .predict import RTDETRPredictor", "chunk_type": "import", "name": "RTDETRPredictor", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRPredictor_f1d87ecc" }, { "content": "from .train import RTDETRTrainer", "chunk_type": "import", "name": "RTDETRTrainer", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRTrainer_a24dc92e" }, { "content": "from .val import RTDETRValidator", "chunk_type": "import", "name": "RTDETRValidator", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRValidator_c77a4be2" }, { "content": "class RTDETR(Model):\n \"\"\"\n Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.\n\n This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware\n query selection, and adaptable inference speed.\n\n Attributes:\n model (str): Path to the pre-trained model.\n\n Methods:\n task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.\n\n Examples:\n Initialize RT-DETR with a pre-trained model\n >>> from ultralytics import RTDETR\n >>> model = RTDETR(\"rtdetr-l.pt\")\n >>> results = model(\"image.jpg\")\n \"\"\"\n\n def __init__(self, model: str = \"rtdetr-l.pt\") -> None:\n \"\"\"\n Initialize the RT-DETR model with the given pre-trained model file.\n\n Args:\n model (str): Path to the pre-trained model. 
Supports .pt, .yaml, and .yml formats.\n \"\"\"\n super().__init__(model=model, task=\"detect\")\n\n @property\n def task_map(self) -> dict:\n \"\"\"\n Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.\n\n Returns:\n (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.\n \"\"\"\n return {\n \"detect\": {\n \"predictor\": RTDETRPredictor,\n \"validator\": RTDETRValidator,\n \"trainer\": RTDETRTrainer,\n \"model\": RTDETRDetectionModel,\n }\n }", "chunk_type": "class", "name": "RTDETR", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py", "start_line": 20, "end_line": 64, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.\n\nThis model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware\nquery selection, and adaptable inference speed.\n\nAttributes:\n model (str): Path to the pre-trained model.\n\nMethods:\n task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.\n\nExamples:\n Initialize RT-DETR with a pre-trained model\n >>> from ultralytics import RTDETR\n >>> model = RTDETR(\"rtdetr-l.pt\")\n >>> results = model(\"image.jpg\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ultralytics.engine.model.Model", "ultralytics.nn.tasks.RTDETRDetectionModel", "predict.RTDETRPredictor", "train.RTDETRTrainer", "val.RTDETRValidator", "Model" ], "chunk_id": "class_RTDETR_3dbd933c" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_bdcfae9f" }, { "content": "from ultralytics.data.augment import LetterBox", "chunk_type": "import", "name": "LetterBox", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LetterBox_58b6ab82" }, { "content": "from ultralytics.engine.predictor import BasePredictor", "chunk_type": "import", "name": "BasePredictor", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BasePredictor_c18cc8ab" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_c349127c" }, { "content": "from ultralytics.utils import ops", "chunk_type": "import", "name": "ops", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 
33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ops_6f996d74" }, { "content": "class RTDETRPredictor(BasePredictor):\n \"\"\"\n RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.\n\n This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.\n It supports key features like efficient hybrid encoding and IoU-aware query selection.\n\n Attributes:\n imgsz (int): Image size for inference (must be square and scale-filled).\n args (dict): Argument overrides for the predictor.\n model (torch.nn.Module): The loaded RT-DETR model.\n batch (list): Current batch of processed inputs.\n\n Methods:\n postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.\n pre_transform: Pre-transform input images before feeding them into the model for inference.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.rtdetr import RTDETRPredictor\n >>> args = dict(model=\"rtdetr-l.pt\", source=ASSETS)\n >>> predictor = RTDETRPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.\n\n The method filters detections based on confidence and class if specified in `self.args`. It converts\n model predictions to Results objects containing properly scaled bounding boxes.\n\n Args:\n preds (list | tuple): List of [predictions, extra] from the model, where predictions contain\n bounding boxes and scores.\n img (torch.Tensor): Processed input images with shape (N, 3, H, W).\n orig_imgs (list | torch.Tensor): Original, unprocessed images.\n\n Returns:\n results (List[Results]): A list of Results objects containing the post-processed bounding boxes,\n confidence scores, and class labels.\n \"\"\"\n if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference\n preds = [preds, None]\n\n nd = preds[0].shape[-1]\n bboxes, scores = preds[0].split((4, nd - 4), dim=-1)\n\n if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n results = []\n for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]): # (300, 4)\n bbox = ops.xywh2xyxy(bbox)\n max_score, cls = score.max(-1, keepdim=True) # (300, 1)\n idx = max_score.squeeze(-1) > self.args.conf # (300, )\n if self.args.classes is not None:\n idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx\n pred = torch.cat([bbox, max_score, cls], dim=-1)[idx] # filter\n oh, ow = orig_img.shape[:2]\n pred[..., [0, 2]] *= ow # scale x coordinates to original width\n pred[..., [1, 3]] *= oh # scale y coordinates to original height\n results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))\n return results\n\n def pre_transform(self, im):\n \"\"\"\n Pre-transform input images before feeding them into the model for inference.\n\n The input images are letterboxed to ensure a square aspect ratio and scale-filled. 
The size must be square\n (640) and scale_filled.\n\n Args:\n im (List[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,\n [(H, W, 3) x N] for list.\n\n Returns:\n (list): List of pre-transformed images ready for model inference.\n \"\"\"\n letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)\n return [letterbox(image=x) for x in im]", "chunk_type": "class", "name": "RTDETRPredictor", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py", "start_line": 11, "end_line": 91, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.\n\nThis class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.\nIt supports key features like efficient hybrid encoding and IoU-aware query selection.\n\nAttributes:\n imgsz (int): Image size for inference (must be square and scale-filled).\n args (dict): Argument overrides for the predictor.\n model (torch.nn.Module): The loaded RT-DETR model.\n batch (list): Current batch of processed inputs.\n\nMethods:\n postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.\n pre_transform: Pre-transform input images before feeding them into the model for inference.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.rtdetr import RTDETRPredictor\n >>> args = dict(model=\"rtdetr-l.pt\", source=ASSETS)\n >>> predictor = RTDETRPredictor(overrides=args)\n >>> predictor.predict_cli()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "ultralytics.data.augment.LetterBox", "ultralytics.engine.predictor.BasePredictor", "ultralytics.engine.results.Results", "ultralytics.utils.ops", "BasePredictor" ], "chunk_id": "class_RTDETRPredictor_781515ed" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_8c82df20" }, { "content": "from typing import Optional", "chunk_type": "import", "name": "Optional", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Optional_818d31a9" }, { "content": "from ultralytics.models.yolo.detect import DetectionTrainer", "chunk_type": "import", "name": "DetectionTrainer", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionTrainer_51568663" }, { "content": "from ultralytics.nn.tasks import RTDETRDetectionModel", "chunk_type": "import", "name": "RTDETRDetectionModel", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": 
null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRDetectionModel_8a37c91d" }, { "content": "from ultralytics.utils import RANK, colorstr", "chunk_type": "import", "name": "RANK, colorstr", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RANK, colorstr_0b78d91a" }, { "content": "from .val import RTDETRDataset, RTDETRValidator", "chunk_type": "import", "name": "RTDETRDataset, RTDETRValidator", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRDataset, RTDETRValidator_a803f6b9" }, { "content": "class RTDETRTrainer(DetectionTrainer):\n \"\"\"\n Trainer class for the RT-DETR model developed by Baidu for real-time object detection.\n\n This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.\n The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference\n speed.\n\n Attributes:\n loss_names (tuple): Names of the loss components used for training.\n data (dict): Dataset configuration containing class count and other parameters.\n args (dict): Training arguments and hyperparameters.\n save_dir (Path): Directory to save training results.\n test_loader (DataLoader): DataLoader for validation/testing data.\n\n Methods:\n get_model: Initialize and return an RT-DETR model for object detection tasks.\n build_dataset: Build and return an RT-DETR dataset for training or validation.\n get_validator: Return a DetectionValidator suitable for RT-DETR model validation.\n\n Notes:\n - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.\n - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.\n\n Examples:\n >>> from ultralytics.models.rtdetr.train import RTDETRTrainer\n >>> args = dict(model=\"rtdetr-l.yaml\", data=\"coco8.yaml\", imgsz=640, epochs=3)\n >>> trainer = RTDETRTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def get_model(self, cfg: Optional[dict] = None, weights: Optional[str] = None, verbose: bool = True):\n \"\"\"\n Initialize and return an RT-DETR model for object detection tasks.\n\n Args:\n cfg (dict, optional): Model configuration.\n weights (str, optional): Path to pre-trained model weights.\n verbose (bool): Verbose logging if True.\n\n Returns:\n (RTDETRDetectionModel): Initialized model.\n \"\"\"\n model = RTDETRDetectionModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n return model\n\n def build_dataset(self, img_path: str, mode: str = \"val\", batch: Optional[int] = None):\n \"\"\"\n Build and return an RT-DETR dataset for training or validation.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): Dataset mode, either 'train' or 'val'.\n batch (int, optional): Batch size for rectangle training.\n\n Returns:\n (RTDETRDataset): Dataset object for the specific mode.\n \"\"\"\n return RTDETRDataset(\n img_path=img_path,\n imgsz=self.args.imgsz,\n 
batch_size=batch,\n augment=mode == \"train\",\n hyp=self.args,\n rect=False,\n cache=self.args.cache or None,\n single_cls=self.args.single_cls or False,\n prefix=colorstr(f\"{mode}: \"),\n classes=self.args.classes,\n data=self.data,\n fraction=self.args.fraction if mode == \"train\" else 1.0,\n )\n\n def get_validator(self):\n \"\"\"Return a DetectionValidator suitable for RT-DETR model validation.\"\"\"\n self.loss_names = \"giou_loss\", \"cls_loss\", \"l1_loss\"\n return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))", "chunk_type": "class", "name": "RTDETRTrainer", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py", "start_line": 13, "end_line": 91, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": "Trainer class for the RT-DETR model developed by Baidu for real-time object detection.\n\nThis class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.\nThe model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference\nspeed.\n\nAttributes:\n loss_names (tuple): Names of the loss components used for training.\n data (dict): Dataset configuration containing class count and other parameters.\n args (dict): Training arguments and hyperparameters.\n save_dir (Path): Directory to save training results.\n test_loader (DataLoader): DataLoader for validation/testing data.\n\nMethods:\n get_model: Initialize and return an RT-DETR model for object detection tasks.\n build_dataset: Build and return an RT-DETR dataset for training or validation.\n get_validator: Return a DetectionValidator suitable for RT-DETR model validation.\n\nNotes:\n - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.\n - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.\n\nExamples:\n >>> from ultralytics.models.rtdetr.train import RTDETRTrainer\n >>> args = dict(model=\"rtdetr-l.yaml\", data=\"coco8.yaml\", imgsz=640, epochs=3)\n >>> trainer = RTDETRTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "typing.Optional", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.nn.tasks.RTDETRDetectionModel", "ultralytics.utils.RANK", "ultralytics.utils.colorstr", "val.RTDETRDataset", "val.RTDETRValidator", "DetectionTrainer" ], "chunk_id": "class_RTDETRTrainer_68d1f2ef" }, { "content": "from typing import Any, Dict, List, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Tuple, Union", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple, Union_ddc2db15" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_d7d425c3" }, { "content": "from ultralytics.data import YOLODataset", "chunk_type": "import", "name": "YOLODataset", "file_path": 
"ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLODataset_b73f5263" }, { "content": "from ultralytics.data.augment import Compose, Format, v8_transforms", "chunk_type": "import", "name": "Compose, Format, v8_transforms", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Compose, Format, v8_transforms_82ece729" }, { "content": "from ultralytics.models.yolo.detect import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_ac673e65" }, { "content": "from ultralytics.utils import colorstr, ops", "chunk_type": "import", "name": "colorstr, ops", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_colorstr, ops_2d167c92" }, { "content": "__all__ = (\"RTDETRValidator\",) # tuple or list", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___376b9aa0" }, { "content": "class RTDETRDataset(YOLODataset):\n \"\"\"\n Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.\n\n This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for\n real-time detection and tracking tasks.\n\n Attributes:\n augment (bool): Whether to apply data augmentation.\n rect (bool): Whether to use rectangular training.\n use_segments (bool): Whether to use segmentation masks.\n use_keypoints (bool): Whether to use keypoint annotations.\n imgsz (int): Target image size for training.\n\n Methods:\n load_image: Load one image from dataset index.\n build_transforms: Build transformation pipeline for the dataset.\n\n Examples:\n Initialize an RT-DETR dataset\n >>> dataset = RTDETRDataset(img_path=\"path/to/images\", imgsz=640)\n >>> image, hw = dataset.load_image(0)\n \"\"\"\n\n def __init__(self, *args, data=None, **kwargs):\n \"\"\"\n Initialize the RTDETRDataset class by inheriting from the YOLODataset class.\n\n This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)\n model, building upon the base YOLODataset functionality.\n\n Args:\n *args (Any): Variable length argument list passed to the parent YOLODataset class.\n data (dict | None): Dictionary containing dataset information. 
If None, default values will be used.\n **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.\n \"\"\"\n super().__init__(*args, data=data, **kwargs)\n\n def load_image(self, i, rect_mode=False):\n \"\"\"\n Load one image from dataset index 'i'.\n\n Args:\n i (int): Index of the image to load.\n rect_mode (bool, optional): Whether to use rectangular mode for batch inference.\n\n Returns:\n im (torch.Tensor): The loaded image.\n resized_hw (tuple): Height and width of the resized image with shape (2,).\n\n Examples:\n Load an image from the dataset\n >>> dataset = RTDETRDataset(img_path=\"path/to/images\")\n >>> image, hw = dataset.load_image(0)\n \"\"\"\n return super().load_image(i=i, rect_mode=rect_mode)\n\n def build_transforms(self, hyp=None):\n \"\"\"\n Build transformation pipeline for the dataset.\n\n Args:\n hyp (dict, optional): Hyperparameters for transformations.\n\n Returns:\n (Compose): Composition of transformation functions.\n \"\"\"\n if self.augment:\n hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0\n hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0\n hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0\n transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)\n else:\n # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])\n transforms = Compose([])\n transforms.append(\n Format(\n bbox_format=\"xywh\",\n normalize=True,\n return_mask=self.use_segments,\n return_keypoint=self.use_keypoints,\n batch_idx=True,\n mask_ratio=hyp.mask_ratio,\n mask_overlap=hyp.overlap_mask,\n )\n )\n return transforms", "chunk_type": "class", "name": "RTDETRDataset", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 15, "end_line": 101, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.\n\nThis specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for\nreal-time detection and tracking tasks.\n\nAttributes:\n augment (bool): Whether to apply data augmentation.\n rect (bool): Whether to use rectangular training.\n use_segments (bool): Whether to use segmentation masks.\n use_keypoints (bool): Whether to use keypoint annotations.\n imgsz (int): Target image size for training.\n\nMethods:\n load_image: Load one image from dataset index.\n build_transforms: Build transformation pipeline for the dataset.\n\nExamples:\n Initialize an RT-DETR dataset\n >>> dataset = RTDETRDataset(img_path=\"path/to/images\", imgsz=640)\n >>> image, hw = dataset.load_image(0)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "torch", "ultralytics.data.YOLODataset", "ultralytics.data.augment.Compose", "ultralytics.data.augment.Format", "ultralytics.data.augment.v8_transforms", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.utils.colorstr", "ultralytics.utils.ops", "YOLODataset" ], "chunk_id": "class_RTDETRDataset_6b7a2796" }, { "content": "class RTDETRValidator(DetectionValidator):\n \"\"\"\n RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for\n the RT-DETR (Real-Time DETR) object detection model.\n\n The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum 
suppression for\n post-processing, and updates evaluation metrics accordingly.\n\n Attributes:\n args (Namespace): Configuration arguments for validation.\n data (dict): Dataset configuration dictionary.\n\n Methods:\n build_dataset: Build an RTDETR Dataset for validation.\n postprocess: Apply Non-maximum suppression to prediction outputs.\n\n Examples:\n Initialize and run RT-DETR validation\n >>> from ultralytics.models.rtdetr import RTDETRValidator\n >>> args = dict(model=\"rtdetr-l.pt\", data=\"coco8.yaml\")\n >>> validator = RTDETRValidator(args=args)\n >>> validator()\n\n Notes:\n For further details on the attributes and methods, refer to the parent DetectionValidator class.\n \"\"\"\n\n def build_dataset(self, img_path, mode=\"val\", batch=None):\n \"\"\"\n Build an RTDETR Dataset.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for\n each mode.\n batch (int, optional): Size of batches, this is for `rect`.\n\n Returns:\n (RTDETRDataset): Dataset configured for RT-DETR validation.\n \"\"\"\n return RTDETRDataset(\n img_path=img_path,\n imgsz=self.args.imgsz,\n batch_size=batch,\n augment=False, # no augmentation\n hyp=self.args,\n rect=False, # no rect\n cache=self.args.cache or None,\n prefix=colorstr(f\"{mode}: \"),\n data=self.data,\n )\n\n def postprocess(\n self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]\n ) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Apply Non-maximum suppression to prediction outputs.\n\n Args:\n preds (torch.Tensor | List | Tuple): Raw predictions from the model. If tensor, should have shape\n (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.\n\n Returns:\n (List[Dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:\n - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates\n - 'conf': Tensor of shape (N,) with confidence scores\n - 'cls': Tensor of shape (N,) with class indices\n \"\"\"\n if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference\n preds = [preds, None]\n\n bs, _, nd = preds[0].shape\n bboxes, scores = preds[0].split((4, nd - 4), dim=-1)\n bboxes *= self.args.imgsz\n outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs\n for i, bbox in enumerate(bboxes): # (300, 4)\n bbox = ops.xywh2xyxy(bbox)\n score, cls = scores[i].max(-1) # (300, )\n pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter\n # Sort by confidence to correctly get internal metrics\n pred = pred[score.argsort(descending=True)]\n outputs[i] = pred[score > self.args.conf]\n\n return [{\"bboxes\": x[:, :4], \"conf\": x[:, 4], \"cls\": x[:, 5]} for x in outputs]\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch for validation by applying necessary transformations.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Batch data containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Prepared batch with transformed annotations containing cls, bboxes,\n ori_shape, imgsz, and ratio_pad.\n \"\"\"\n idx = batch[\"batch_idx\"] == si\n cls = batch[\"cls\"][idx].squeeze(-1)\n bbox = batch[\"bboxes\"][idx]\n ori_shape = batch[\"ori_shape\"][si]\n imgsz = batch[\"img\"].shape[2:]\n ratio_pad = batch[\"ratio_pad\"][si]\n if len(cls):\n bbox = ops.xywh2xyxy(bbox) # target boxes\n bbox[..., [0, 2]] *= 
ori_shape[1] # native-space pred\n bbox[..., [1, 3]] *= ori_shape[0] # native-space pred\n return {\"cls\": cls, \"bboxes\": bbox, \"ori_shape\": ori_shape, \"imgsz\": imgsz, \"ratio_pad\": ratio_pad}\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions by scaling bounding boxes to original image dimensions.\n\n Args:\n pred (Dict[str, torch.Tensor]): Raw predictions containing 'cls', 'bboxes', and 'conf'.\n pbatch (Dict[str, torch.Tensor]): Prepared batch information containing 'ori_shape' and other metadata.\n\n Returns:\n (Dict[str, torch.Tensor]): Predictions scaled to original image dimensions.\n \"\"\"\n cls = pred[\"cls\"]\n if self.args.single_cls:\n cls *= 0\n bboxes = pred[\"bboxes\"].clone()\n bboxes[..., [0, 2]] *= pbatch[\"ori_shape\"][1] / self.args.imgsz # native-space pred\n bboxes[..., [1, 3]] *= pbatch[\"ori_shape\"][0] / self.args.imgsz # native-space pred\n return {\"bboxes\": bboxes, \"conf\": pred[\"conf\"], \"cls\": cls}", "chunk_type": "class", "name": "RTDETRValidator", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py", "start_line": 104, "end_line": 230, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for\nthe RT-DETR (Real-Time DETR) object detection model.\n\nThe class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for\npost-processing, and updates evaluation metrics accordingly.\n\nAttributes:\n args (Namespace): Configuration arguments for validation.\n data (dict): Dataset configuration dictionary.\n\nMethods:\n build_dataset: Build an RTDETR Dataset for validation.\n postprocess: Apply Non-maximum suppression to prediction outputs.\n\nExamples:\n Initialize and run RT-DETR validation\n >>> from ultralytics.models.rtdetr import RTDETRValidator\n >>> args = dict(model=\"rtdetr-l.pt\", data=\"coco8.yaml\")\n >>> validator = RTDETRValidator(args=args)\n >>> validator()\n\nNotes:\n For further details on the attributes and methods, refer to the parent DetectionValidator class.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "torch", "ultralytics.data.YOLODataset", "ultralytics.data.augment.Compose", "ultralytics.data.augment.Format", "ultralytics.data.augment.v8_transforms", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.utils.colorstr", "ultralytics.utils.ops", "DetectionValidator" ], "chunk_id": "class_RTDETRValidator_cb87324e" }, { "content": "from .model import RTDETR", "chunk_type": "import", "name": "RTDETR", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETR_a9b0b4d3" }, { "content": "from .predict import RTDETRPredictor", "chunk_type": "import", "name": "RTDETRPredictor", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
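RTDETRValidator.postprocess above replaces NMS-style filtering with a plain confidence threshold: RT-DETR emits a fixed set of queries, so the raw (batch, 300, 4 + nc) tensor is just split, scaled, sorted and thresholded. The sketch below replays those decode steps on a random tensor, assuming an ultralytics install for ops.xywh2xyxy; the 80-class count and 0.25 threshold are illustrative.

```python
import torch
from ultralytics.utils import ops

# Hypothetical raw RT-DETR output: batch of 1, 300 queries, 4 box coords + 80 class scores
imgsz, conf_thres = 640, 0.25
preds = torch.rand(1, 300, 4 + 80)

bboxes, scores = preds.split((4, preds.shape[-1] - 4), dim=-1)
boxes = ops.xywh2xyxy(bboxes[0] * imgsz)     # scale normalized xywh to pixels, convert to xyxy
conf, cls = scores[0].max(-1)                # best class score and index per query
keep = conf > conf_thres
order = conf[keep].argsort(descending=True)  # sort retained queries by confidence
result = {"bboxes": boxes[keep][order], "conf": conf[keep][order], "cls": cls[keep][order]}
print({k: tuple(v.shape) for k, v in result.items()})
```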
"import_RTDETRPredictor_47e1edf7" }, { "content": "from .val import RTDETRValidator", "chunk_type": "import", "name": "RTDETRValidator", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RTDETRValidator_8974b27e" }, { "content": "__all__ = \"RTDETRPredictor\", \"RTDETRValidator\", \"RTDETR\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___cb0272bf" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_de6c83bb" }, { "content": "from itertools import product", "chunk_type": "import", "name": "product", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_product_17266aff" }, { "content": "from typing import Any, Generator, List, Tuple", "chunk_type": "import", "name": "Any, Generator, List, Tuple", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Generator, List, Tuple_49540f0a" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_54d8819c" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_b4c550d5" }, { "content": "def is_box_near_crop_edge(\n boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0\n) -> torch.Tensor:\n \"\"\"\n Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.\n\n Args:\n boxes (torch.Tensor): Bounding boxes in XYXY format.\n crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.\n orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.\n atol (float, optional): Absolute tolerance for edge proximity detection.\n\n Returns:\n (torch.Tensor): Boolean tensor indicating which 
boxes are near crop edges.\n\n Examples:\n >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])\n >>> crop_box = [0, 0, 200, 200]\n >>> orig_box = [0, 0, 300, 300]\n >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)\n \"\"\"\n crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)\n orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)\n boxes = uncrop_boxes_xyxy(boxes, crop_box).float()\n near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)\n near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)\n near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)\n return torch.any(near_crop_edge, dim=1)", "chunk_type": "function", "name": "is_box_near_crop_edge", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 11, "end_line": 38, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.\n\nArgs:\n boxes (torch.Tensor): Bounding boxes in XYXY format.\n crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.\n orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.\n atol (float, optional): Absolute tolerance for edge proximity detection.\n\nReturns:\n (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.\n\nExamples:\n >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])\n >>> crop_box = [0, 0, 200, 200]\n >>> orig_box = [0, 0, 300, 300]\n >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)", "parameters": [ "boxes: torch.Tensor", "crop_box: List[int]", "orig_box: List[int]", "atol: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_is_box_near_crop_edge_d6d98eba" }, { "content": "def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:\n \"\"\"\n Yield batches of data from input arguments with specified batch size for efficient processing.\n\n This function takes a batch size and any number of iterables, then yields batches of elements from those\n iterables. All input iterables must have the same length.\n\n Args:\n batch_size (int): Size of each batch to yield.\n *args (Any): Variable length input iterables to batch. All iterables must have the same length.\n\n Yields:\n (List[Any]): A list of batched elements from each input iterable.\n\n Examples:\n >>> data = [1, 2, 3, 4, 5]\n >>> labels = [\"a\", \"b\", \"c\", \"d\", \"e\"]\n >>> for batch in batch_iterator(2, data, labels):\n ... 
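A quick check of is_box_near_crop_edge, assuming ultralytics is installed so the ultralytics.models.sam.amg import path from the chunk above resolves. The coordinates are synthetic; the point is that only boxes near an internal crop boundary (not near the original image border) are flagged:

```python
import torch
from ultralytics.models.sam.amg import is_box_near_crop_edge

crop_box = [0, 0, 200, 200]   # a crop taken from the top-left of a 300x300 image
orig_box = [0, 0, 300, 300]   # full-image extent
boxes = torch.tensor(
    [
        [150.0, 80.0, 199.0, 120.0],  # touches the crop's internal right edge -> flagged
        [20.0, 20.0, 60.0, 60.0],     # near the shared top-left corner only -> not flagged
    ]
)
print(is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0))  # tensor([ True, False])
```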
print(batch)\n [[1, 2], ['a', 'b']]\n [[3, 4], ['c', 'd']]\n [[5], ['e']]\n \"\"\"\n assert args and all(len(a) == len(args[0]) for a in args), \"Batched iteration must have same-size inputs.\"\n n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)\n for b in range(n_batches):\n yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]", "chunk_type": "function", "name": "batch_iterator", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 41, "end_line": 67, "start_col": 0, "end_col": 74, "parent_name": null, "docstring": "Yield batches of data from input arguments with specified batch size for efficient processing.\n\nThis function takes a batch size and any number of iterables, then yields batches of elements from those\niterables. All input iterables must have the same length.\n\nArgs:\n batch_size (int): Size of each batch to yield.\n *args (Any): Variable length input iterables to batch. All iterables must have the same length.\n\nYields:\n (List[Any]): A list of batched elements from each input iterable.\n\nExamples:\n >>> data = [1, 2, 3, 4, 5]\n >>> labels = [\"a\", \"b\", \"c\", \"d\", \"e\"]\n >>> for batch in batch_iterator(2, data, labels):\n ... print(batch)\n [[1, 2], ['a', 'b']]\n [[3, 4], ['c', 'd']]\n [[5], ['e']]", "parameters": [ "batch_size: int" ], "return_type": "Generator[List[Any], None, None]", "decorators": [], "complexity_score": 4, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_batch_iterator_e43338cb" }, { "content": "def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:\n \"\"\"\n Compute the stability score for a batch of masks.\n\n The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at\n high and low values.\n\n Args:\n masks (torch.Tensor): Batch of predicted mask logits.\n mask_threshold (float): Threshold value for creating binary masks.\n threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.\n\n Returns:\n (torch.Tensor): Stability scores for each mask in the batch.\n\n Notes:\n - One mask is always contained inside the other.\n - Memory is saved by preventing unnecessary cast to torch.int64.\n\n Examples:\n >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks\n >>> mask_threshold = 0.5\n >>> threshold_offset = 0.1\n >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)\n \"\"\"\n intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)\n unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)\n return intersections / unions", "chunk_type": "function", "name": "calculate_stability_score", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 70, "end_line": 97, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Compute the stability score for a batch of masks.\n\nThe stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at\nhigh and low values.\n\nArgs:\n masks (torch.Tensor): Batch of predicted mask logits.\n mask_threshold (float): Threshold value for creating binary masks.\n threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.\n\nReturns:\n 
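calculate_stability_score above is a ratio of two thresholded areas, so a synthetic logit map makes its behaviour concrete. The example assumes an ultralytics install; the logit values and thresholds are illustrative:

```python
import torch
from ultralytics.models.sam.amg import calculate_stability_score

logits = torch.full((1, 64, 64), -10.0)   # one synthetic mask of logits
logits[0, 16:48, 16:48] = 10.0            # confidently-inside 32x32 square
logits[0, 30:34, 30:34] = 0.2             # small patch that sits between the two thresholds
score = calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0)
print(score)  # ~0.984: only the 16 borderline pixels flip between thresholds -1.0 and +1.0
```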
(torch.Tensor): Stability scores for each mask in the batch.\n\nNotes:\n - One mask is always contained inside the other.\n - Memory is saved by preventing unnecessary cast to torch.int64.\n\nExamples:\n >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks\n >>> mask_threshold = 0.5\n >>> threshold_offset = 0.1\n >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)", "parameters": [ "masks: torch.Tensor", "mask_threshold: float", "threshold_offset: float" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_calculate_stability_score_eb844f5d" }, { "content": "def build_point_grid(n_per_side: int) -> np.ndarray:\n \"\"\"Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.\"\"\"\n offset = 1 / (2 * n_per_side)\n points_one_side = np.linspace(offset, 1 - offset, n_per_side)\n points_x = np.tile(points_one_side[None, :], (n_per_side, 1))\n points_y = np.tile(points_one_side[:, None], (1, n_per_side))\n return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)", "chunk_type": "function", "name": "build_point_grid", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 100, "end_line": 106, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": "Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.", "parameters": [ "n_per_side: int" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 1, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_build_point_grid_0b97c27c" }, { "content": "def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:\n \"\"\"Generate point grids for multiple crop layers with varying scales and densities.\"\"\"\n return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]", "chunk_type": "function", "name": "build_all_layer_point_grids", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 109, "end_line": 111, "start_col": 0, "end_col": 98, "parent_name": null, "docstring": "Generate point grids for multiple crop layers with varying scales and densities.", "parameters": [ "n_per_side: int", "n_layers: int", "scale_per_layer: int" ], "return_type": "List[np.ndarray]", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_build_all_layer_point_grids_e3ffef29" }, { "content": "def generate_crop_boxes(\n im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float\n) -> Tuple[List[List[int]], List[int]]:\n \"\"\"\n Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.\n\n Args:\n im_size (Tuple[int, ...]): Height and width of the input image.\n n_layers (int): Number of layers to generate crop boxes for.\n overlap_ratio (float): Ratio of overlap between adjacent crop boxes.\n\n Returns:\n crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.\n layer_idxs (List[int]): List of layer indices corresponding to each crop box.\n\n Examples:\n >>> im_size = (800, 1200) # Height, width\n >>> 
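The two grid builders above are pure NumPy helpers, so they are easy to sanity-check directly (assuming ultralytics is installed; the sizes below are illustrative):

```python
from ultralytics.models.sam.amg import build_point_grid, build_all_layer_point_grids

grid = build_point_grid(4)                 # 4x4 = 16 evenly spaced (x, y) points in [0, 1]
print(grid.shape, grid.min(), grid.max())  # (16, 2) 0.125 0.875

grids = build_all_layer_point_grids(32, n_layers=2, scale_per_layer=2)
print([g.shape[0] for g in grids])         # [1024, 256, 64] points for layers 0, 1, 2
```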
n_layers = 3\n >>> overlap_ratio = 0.25\n >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)\n \"\"\"\n crop_boxes, layer_idxs = [], []\n im_h, im_w = im_size\n short_side = min(im_h, im_w)\n\n # Original image\n crop_boxes.append([0, 0, im_w, im_h])\n layer_idxs.append(0)\n\n def crop_len(orig_len, n_crops, overlap):\n \"\"\"Calculate the length of each crop given the original length, number of crops, and overlap.\"\"\"\n return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))\n\n for i_layer in range(n_layers):\n n_crops_per_side = 2 ** (i_layer + 1)\n overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))\n\n crop_w = crop_len(im_w, n_crops_per_side, overlap)\n crop_h = crop_len(im_h, n_crops_per_side, overlap)\n\n crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]\n crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]\n\n # Crops in XYWH format\n for x0, y0 in product(crop_box_x0, crop_box_y0):\n box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]\n crop_boxes.append(box)\n layer_idxs.append(i_layer + 1)\n\n return crop_boxes, layer_idxs", "chunk_type": "function", "name": "generate_crop_boxes", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 114, "end_line": 163, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.\n\nArgs:\n im_size (Tuple[int, ...]): Height and width of the input image.\n n_layers (int): Number of layers to generate crop boxes for.\n overlap_ratio (float): Ratio of overlap between adjacent crop boxes.\n\nReturns:\n crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.\n layer_idxs (List[int]): List of layer indices corresponding to each crop box.\n\nExamples:\n >>> im_size = (800, 1200) # Height, width\n >>> n_layers = 3\n >>> overlap_ratio = 0.25\n >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)", "parameters": [ "im_size: Tuple[int, ...]", "n_layers: int", "overlap_ratio: float" ], "return_type": "Tuple[List[List[int]], List[int]]", "decorators": [], "complexity_score": 5, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_generate_crop_boxes_9af763b4" }, { "content": "def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:\n \"\"\"Uncrop bounding boxes by adding the crop box offset to their coordinates.\"\"\"\n x0, y0, _, _ = crop_box\n offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)\n # Check if boxes has a channel dimension\n if len(boxes.shape) == 3:\n offset = offset.unsqueeze(1)\n return boxes + offset", "chunk_type": "function", "name": "uncrop_boxes_xyxy", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 166, "end_line": 173, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "Uncrop bounding boxes by adding the crop box offset to their coordinates.", "parameters": [ "boxes: torch.Tensor", "crop_box: List[int]" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_uncrop_boxes_xyxy_a4e9fac7" }, { "content": "def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> 
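generate_crop_boxes above always returns the full image as layer 0 and then 2^(i+1) x 2^(i+1) overlapping crops per extra layer. A small run with illustrative sizes (assumes ultralytics is installed):

```python
from ultralytics.models.sam.amg import generate_crop_boxes

# One extra layer of 2x2 overlapping crops on top of the full 800x600 image
crop_boxes, layer_idxs = generate_crop_boxes(im_size=(600, 800), n_layers=1, overlap_ratio=0.25)
print(layer_idxs)     # [0, 1, 1, 1, 1]
print(crop_boxes[0])  # [0, 0, 800, 600] -> the untouched original image
print(crop_boxes[1])  # [0, 0, 475, 375] -> first of the four overlapping crops
```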
torch.Tensor:\n \"\"\"Uncrop points by adding the crop box offset to their coordinates.\"\"\"\n x0, y0, _, _ = crop_box\n offset = torch.tensor([[x0, y0]], device=points.device)\n # Check if points has a channel dimension\n if len(points.shape) == 3:\n offset = offset.unsqueeze(1)\n return points + offset", "chunk_type": "function", "name": "uncrop_points", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 176, "end_line": 183, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "Uncrop points by adding the crop box offset to their coordinates.", "parameters": [ "points: torch.Tensor", "crop_box: List[int]" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_uncrop_points_eac1907e" }, { "content": "def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:\n \"\"\"Uncrop masks by padding them to the original image size, handling coordinate transformations.\"\"\"\n x0, y0, x1, y1 = crop_box\n if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:\n return masks\n # Coordinate transform masks\n pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)\n pad = (x0, pad_x - x0, y0, pad_y - y0)\n return torch.nn.functional.pad(masks, pad, value=0)", "chunk_type": "function", "name": "uncrop_masks", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 186, "end_line": 194, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Uncrop masks by padding them to the original image size, handling coordinate transformations.", "parameters": [ "masks: torch.Tensor", "crop_box: List[int]", "orig_h: int", "orig_w: int" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_uncrop_masks_5b18c7a2" }, { "content": "def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:\n \"\"\"\n Remove small disconnected regions or holes in a mask based on area threshold and mode.\n\n Args:\n mask (np.ndarray): Binary mask to process.\n area_thresh (float): Area threshold below which regions will be removed.\n mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected\n regions.\n\n Returns:\n processed_mask (np.ndarray): Processed binary mask with small regions removed.\n modified (bool): Whether any regions were modified.\n\n Examples:\n >>> mask = np.zeros((100, 100), dtype=np.bool_)\n >>> mask[40:60, 40:60] = True # Create a square\n >>> mask[45:55, 45:55] = False # Create a hole\n >>> processed_mask, modified = remove_small_regions(mask, 50, \"holes\")\n \"\"\"\n import cv2 # type: ignore\n\n assert mode in {\"holes\", \"islands\"}, f\"Provided mode {mode} is invalid\"\n correct_holes = mode == \"holes\"\n working_mask = (correct_holes ^ mask).astype(np.uint8)\n n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)\n sizes = stats[:, -1][1:] # Row 0 is background label\n small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]\n if not small_regions:\n return mask, False\n fill_labels = [0] + small_regions\n if not correct_holes:\n # If every region is below threshold, keep largest\n fill_labels = [i for i in 
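The uncrop_* helpers above simply add the crop offset back, with an extra unsqueeze when a channel dimension is present. A minimal check with made-up coordinates (assumes ultralytics and torch are installed):

```python
import torch
from ultralytics.models.sam.amg import uncrop_boxes_xyxy, uncrop_points

crop_box = [100, 50, 300, 250]                    # crop located at offset (x0=100, y0=50)
boxes = torch.tensor([[10.0, 20.0, 60.0, 80.0]])  # box predicted in crop coordinates
points = torch.tensor([[25.0, 35.0]])             # point prompt in crop coordinates
print(uncrop_boxes_xyxy(boxes, crop_box))  # [[110., 70., 160., 130.]]
print(uncrop_points(points, crop_box))     # [[125., 85.]]
```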
range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]\n mask = np.isin(regions, fill_labels)\n return mask, True", "chunk_type": "function", "name": "remove_small_regions", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 197, "end_line": 232, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "Remove small disconnected regions or holes in a mask based on area threshold and mode.\n\nArgs:\n mask (np.ndarray): Binary mask to process.\n area_thresh (float): Area threshold below which regions will be removed.\n mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected\n regions.\n\nReturns:\n processed_mask (np.ndarray): Processed binary mask with small regions removed.\n modified (bool): Whether any regions were modified.\n\nExamples:\n >>> mask = np.zeros((100, 100), dtype=np.bool_)\n >>> mask[40:60, 40:60] = True # Create a square\n >>> mask[45:55, 45:55] = False # Create a hole\n >>> processed_mask, modified = remove_small_regions(mask, 50, \"holes\")", "parameters": [ "mask: np.ndarray", "area_thresh: float", "mode: str" ], "return_type": "Tuple[np.ndarray, bool]", "decorators": [], "complexity_score": 5, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_remove_small_regions_8f6fa57b" }, { "content": "def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Calculate bounding boxes in XYXY format around binary masks.\n\n Args:\n masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).\n\n Returns:\n (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).\n\n Notes:\n - Handles empty masks by returning zero boxes.\n - Preserves input tensor dimensions in the output.\n \"\"\"\n # torch.max below raises an error on empty inputs, just skip in this case\n if torch.numel(masks) == 0:\n return torch.zeros(*masks.shape[:-2], 4, device=masks.device)\n\n # Normalize shape to CxHxW\n shape = masks.shape\n h, w = shape[-2:]\n masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)\n # Get top and bottom edges\n in_height, _ = torch.max(masks, dim=-1)\n in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]\n bottom_edges, _ = torch.max(in_height_coords, dim=-1)\n in_height_coords = in_height_coords + h * (~in_height)\n top_edges, _ = torch.min(in_height_coords, dim=-1)\n\n # Get left and right edges\n in_width, _ = torch.max(masks, dim=-2)\n in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]\n right_edges, _ = torch.max(in_width_coords, dim=-1)\n in_width_coords = in_width_coords + w * (~in_width)\n left_edges, _ = torch.min(in_width_coords, dim=-1)\n\n # If the mask is empty the right edge will be to the left of the left edge.\n # Replace these boxes with [0, 0, 0, 0]\n empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)\n out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)\n out = out * (~empty_filter).unsqueeze(-1)\n\n # Return to original shape\n return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]", "chunk_type": "function", "name": "batched_mask_to_box", "file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py", "start_line": 235, "end_line": 278, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": "Calculate bounding boxes in XYXY format around binary masks.\n\nArgs:\n masks 
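remove_small_regions above relies on OpenCV connected components, so the example assumes both ultralytics and opencv-python are installed. The mask geometry is synthetic; it shows the two modes acting on a small hole and on an isolated speckle:

```python
import numpy as np
from ultralytics.models.sam.amg import remove_small_regions

mask = np.zeros((64, 64), dtype=bool)
mask[10:40, 10:40] = True     # main object (900 px)
mask[20:23, 20:23] = False    # 9 px hole inside the object
mask[55:57, 55:57] = True     # 4 px speckle far from the object

filled, changed = remove_small_regions(mask, area_thresh=25, mode="holes")
print(changed, bool(filled[21, 21]))   # True True  -> the small hole was filled

cleaned, changed = remove_small_regions(filled, area_thresh=25, mode="islands")
print(changed, bool(cleaned[55, 55]))  # True False -> the speckle was dropped
```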
(torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).\n\nReturns:\n (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).\n\nNotes:\n - Handles empty masks by returning zero boxes.\n - Preserves input tensor dimensions in the output.", "parameters": [ "masks: torch.Tensor" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 2, "dependencies": [ "math", "itertools.product", "typing.Any", "typing.Generator", "typing.List", "typing.Tuple", "numpy", "torch", "cv2" ], "chunk_id": "function_batched_mask_to_box_e2200aae" }, { "content": "from functools import partial", "chunk_type": "import", "name": "partial", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_partial_8b42cab9" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_91b55564" }, { "content": "from ultralytics.utils.downloads import attempt_download_asset", "chunk_type": "import", "name": "attempt_download_asset", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_attempt_download_asset_1a278421" }, { "content": "from .modules.decoders import MaskDecoder", "chunk_type": "import", "name": "MaskDecoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MaskDecoder_37076adc" }, { "content": "from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder", "chunk_type": "import", "name": "FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 105, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder_09f926dc" }, { "content": "from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer", "chunk_type": "import", "name": "MemoryAttention, MemoryAttentionLayer", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MemoryAttention, MemoryAttentionLayer_748598b9" }, { "content": "from .modules.sam import SAM2Model, SAMModel", "chunk_type": "import", "name": "SAM2Model, SAMModel", "file_path": 
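batched_mask_to_box above derives boxes purely from row/column maxima and returns zeros for empty masks. A small check with a synthetic boolean batch (assumes ultralytics and torch are installed):

```python
import torch
from ultralytics.models.sam.amg import batched_mask_to_box

masks = torch.zeros(2, 32, 32, dtype=torch.bool)
masks[0, 5:11, 8:20] = True   # rows 5-10, cols 8-19; the second mask stays empty
print(batched_mask_to_box(masks))
# -> [[8, 5, 19, 10], [0, 0, 0, 0]]  (XYXY per mask; an empty mask yields all zeros)
```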
"ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SAM2Model, SAMModel_41e150c1" }, { "content": "from .modules.tiny_encoder import TinyViT", "chunk_type": "import", "name": "TinyViT", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TinyViT_d67c140d" }, { "content": "from .modules.transformer import TwoWayTransformer", "chunk_type": "import", "name": "TwoWayTransformer", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TwoWayTransformer_4299f81a" }, { "content": "def build_sam_vit_h(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters.\"\"\"\n return _build_sam(\n encoder_embed_dim=1280,\n encoder_depth=32,\n encoder_num_heads=16,\n encoder_global_attn_indexes=[7, 15, 23, 31],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam_vit_h", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 23, "end_line": 31, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam_vit_h_b44b4a26" }, { "content": "def build_sam_vit_l(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters.\"\"\"\n return _build_sam(\n encoder_embed_dim=1024,\n encoder_depth=24,\n encoder_num_heads=16,\n encoder_global_attn_indexes=[5, 11, 17, 23],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam_vit_l", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 34, "end_line": 42, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", 
"modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam_vit_l_468e3fc6" }, { "content": "def build_sam_vit_b(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters.\"\"\"\n return _build_sam(\n encoder_embed_dim=768,\n encoder_depth=12,\n encoder_num_heads=12,\n encoder_global_attn_indexes=[2, 5, 8, 11],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam_vit_b", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 45, "end_line": 53, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam_vit_b_1f501484" }, { "content": "def build_mobile_sam(checkpoint=None):\n \"\"\"Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.\"\"\"\n return _build_sam(\n encoder_embed_dim=[64, 128, 160, 320],\n encoder_depth=[2, 2, 6, 2],\n encoder_num_heads=[2, 4, 5, 10],\n encoder_global_attn_indexes=None,\n mobile_sam=True,\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_mobile_sam", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 56, "end_line": 65, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_mobile_sam_f102aafa" }, { "content": "def build_sam2_t(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=96,\n encoder_stages=[1, 2, 7, 2],\n encoder_num_heads=1,\n encoder_global_att_blocks=[5, 7, 9],\n encoder_window_spec=[8, 4, 14, 7],\n encoder_backbone_channel_list=[768, 384, 192, 96],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam2_t", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", 
"start_line": 68, "end_line": 78, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam2_t_9c0c8531" }, { "content": "def build_sam2_s(checkpoint=None):\n \"\"\"Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=96,\n encoder_stages=[1, 2, 11, 2],\n encoder_num_heads=1,\n encoder_global_att_blocks=[7, 10, 13],\n encoder_window_spec=[8, 4, 14, 7],\n encoder_backbone_channel_list=[768, 384, 192, 96],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam2_s", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 81, "end_line": 91, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam2_s_1919150f" }, { "content": "def build_sam2_b(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=112,\n encoder_stages=[2, 3, 16, 3],\n encoder_num_heads=2,\n encoder_global_att_blocks=[12, 16, 20],\n encoder_window_spec=[8, 4, 14, 7],\n encoder_window_spatial_size=[14, 14],\n encoder_backbone_channel_list=[896, 448, 224, 112],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam2_b", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 94, "end_line": 105, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", 
"modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam2_b_ad8831d6" }, { "content": "def build_sam2_l(checkpoint=None):\n \"\"\"Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=144,\n encoder_stages=[2, 6, 36, 4],\n encoder_num_heads=2,\n encoder_global_att_blocks=[23, 33, 43],\n encoder_window_spec=[8, 4, 16, 8],\n encoder_backbone_channel_list=[1152, 576, 288, 144],\n checkpoint=checkpoint,\n )", "chunk_type": "function", "name": "build_sam2_l", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 108, "end_line": 118, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters.", "parameters": [ "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam2_l_b5bf2557" }, { "content": "def _build_sam(\n encoder_embed_dim,\n encoder_depth,\n encoder_num_heads,\n encoder_global_attn_indexes,\n checkpoint=None,\n mobile_sam=False,\n):\n \"\"\"\n Build a Segment Anything Model (SAM) with specified encoder parameters.\n\n Args:\n encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.\n encoder_depth (int | List[int]): Depth of the encoder.\n encoder_num_heads (int | List[int]): Number of attention heads in the encoder.\n encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.\n checkpoint (str | None, optional): Path to the model checkpoint file.\n mobile_sam (bool, optional): Whether to build a Mobile-SAM model.\n\n Returns:\n (SAMModel): A Segment Anything Model instance with the specified architecture.\n\n Examples:\n >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])\n >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)\n \"\"\"\n prompt_embed_dim = 256\n image_size = 1024\n vit_patch_size = 16\n image_embedding_size = image_size // vit_patch_size\n image_encoder = (\n TinyViT(\n img_size=1024,\n in_chans=3,\n num_classes=1000,\n embed_dims=encoder_embed_dim,\n depths=encoder_depth,\n num_heads=encoder_num_heads,\n window_sizes=[7, 7, 14, 7],\n mlp_ratio=4.0,\n drop_rate=0.0,\n drop_path_rate=0.0,\n use_checkpoint=False,\n mbconv_expand_ratio=4.0,\n local_conv_size=3,\n layer_lr_decay=0.8,\n )\n if mobile_sam\n else ImageEncoderViT(\n depth=encoder_depth,\n embed_dim=encoder_embed_dim,\n img_size=image_size,\n mlp_ratio=4,\n norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),\n num_heads=encoder_num_heads,\n patch_size=vit_patch_size,\n qkv_bias=True,\n use_rel_pos=True,\n global_attn_indexes=encoder_global_attn_indexes,\n window_size=14,\n out_chans=prompt_embed_dim,\n )\n )\n sam = SAMModel(\n 
image_encoder=image_encoder,\n prompt_encoder=PromptEncoder(\n embed_dim=prompt_embed_dim,\n image_embedding_size=(image_embedding_size, image_embedding_size),\n input_image_size=(image_size, image_size),\n mask_in_chans=16,\n ),\n mask_decoder=MaskDecoder(\n num_multimask_outputs=3,\n transformer=TwoWayTransformer(\n depth=2,\n embedding_dim=prompt_embed_dim,\n mlp_dim=2048,\n num_heads=8,\n ),\n transformer_dim=prompt_embed_dim,\n iou_head_depth=3,\n iou_head_hidden_dim=256,\n ),\n pixel_mean=[123.675, 116.28, 103.53],\n pixel_std=[58.395, 57.12, 57.375],\n )\n if checkpoint is not None:\n checkpoint = attempt_download_asset(checkpoint)\n with open(checkpoint, \"rb\") as f:\n state_dict = torch.load(f)\n sam.load_state_dict(state_dict)\n sam.eval()\n return sam", "chunk_type": "function", "name": "_build_sam", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 121, "end_line": 213, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Build a Segment Anything Model (SAM) with specified encoder parameters.\n\nArgs:\n encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.\n encoder_depth (int | List[int]): Depth of the encoder.\n encoder_num_heads (int | List[int]): Number of attention heads in the encoder.\n encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.\n checkpoint (str | None, optional): Path to the model checkpoint file.\n mobile_sam (bool, optional): Whether to build a Mobile-SAM model.\n\nReturns:\n (SAMModel): A Segment Anything Model instance with the specified architecture.\n\nExamples:\n >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])\n >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)", "parameters": [ "encoder_embed_dim", "encoder_depth", "encoder_num_heads", "encoder_global_attn_indexes", "checkpoint", "mobile_sam" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function__build_sam_ff65f01d" }, { "content": "def _build_sam2(\n encoder_embed_dim=1280,\n encoder_stages=[2, 6, 36, 4],\n encoder_num_heads=2,\n encoder_global_att_blocks=[7, 15, 23, 31],\n encoder_backbone_channel_list=[1152, 576, 288, 144],\n encoder_window_spatial_size=[7, 7],\n encoder_window_spec=[8, 4, 16, 8],\n checkpoint=None,\n):\n \"\"\"\n Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.\n\n Args:\n encoder_embed_dim (int, optional): Embedding dimension for the encoder.\n encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.\n encoder_num_heads (int, optional): Number of attention heads in the encoder.\n encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.\n encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.\n encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.\n 
encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.\n checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.\n\n Returns:\n (SAM2Model): A configured and initialized SAM2 model.\n\n Examples:\n >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])\n >>> sam2_model.eval()\n \"\"\"\n image_encoder = ImageEncoder(\n trunk=Hiera(\n embed_dim=encoder_embed_dim,\n num_heads=encoder_num_heads,\n stages=encoder_stages,\n global_att_blocks=encoder_global_att_blocks,\n window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,\n window_spec=encoder_window_spec,\n ),\n neck=FpnNeck(\n d_model=256,\n backbone_channel_list=encoder_backbone_channel_list,\n fpn_top_down_levels=[2, 3],\n fpn_interp_model=\"nearest\",\n ),\n scalp=1,\n )\n memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())\n memory_encoder = MemoryEncoder(out_dim=64)\n\n is_sam2_1 = checkpoint is not None and \"sam2.1\" in checkpoint\n sam2 = SAM2Model(\n image_encoder=image_encoder,\n memory_attention=memory_attention,\n memory_encoder=memory_encoder,\n num_maskmem=7,\n image_size=1024,\n sigmoid_scale_for_mem_enc=20.0,\n sigmoid_bias_for_mem_enc=-10.0,\n use_mask_input_as_output_without_sam=True,\n directly_add_no_mem_embed=True,\n use_high_res_features_in_sam=True,\n multimask_output_in_sam=True,\n iou_prediction_use_sigmoid=True,\n use_obj_ptrs_in_encoder=True,\n add_tpos_enc_to_obj_ptrs=True,\n only_obj_ptrs_in_the_past_for_eval=True,\n pred_obj_scores=True,\n pred_obj_scores_mlp=True,\n fixed_no_obj_ptr=True,\n multimask_output_for_tracking=True,\n use_multimask_token_for_obj_ptr=True,\n multimask_min_pt_num=0,\n multimask_max_pt_num=1,\n use_mlp_for_obj_ptr_proj=True,\n compile_image_encoder=False,\n no_obj_embed_spatial=is_sam2_1,\n proj_tpos_enc_in_obj_ptrs=is_sam2_1,\n use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,\n sam_mask_decoder_extra_args=dict(\n dynamic_multimask_via_stability=True,\n dynamic_multimask_stability_delta=0.05,\n dynamic_multimask_stability_thresh=0.98,\n ),\n )\n\n if checkpoint is not None:\n checkpoint = attempt_download_asset(checkpoint)\n with open(checkpoint, \"rb\") as f:\n state_dict = torch.load(f)[\"model\"]\n sam2.load_state_dict(state_dict)\n sam2.eval()\n return sam2", "chunk_type": "function", "name": "_build_sam2", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 216, "end_line": 308, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.\n\nArgs:\n encoder_embed_dim (int, optional): Embedding dimension for the encoder.\n encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.\n encoder_num_heads (int, optional): Number of attention heads in the encoder.\n encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.\n encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.\n encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.\n encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.\n checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.\n\nReturns:\n (SAM2Model): A configured and initialized SAM2 model.\n\nExamples:\n >>> sam2_model = 
_build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])\n >>> sam2_model.eval()", "parameters": [ "encoder_embed_dim", "encoder_stages", "encoder_num_heads", "encoder_global_att_blocks", "encoder_backbone_channel_list", "encoder_window_spatial_size", "encoder_window_spec", "checkpoint" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function__build_sam2_f613f9c0" }, { "content": "sam_model_map = {\n \"sam_h.pt\": build_sam_vit_h,\n \"sam_l.pt\": build_sam_vit_l,\n \"sam_b.pt\": build_sam_vit_b,\n \"mobile_sam.pt\": build_mobile_sam,\n \"sam2_t.pt\": build_sam2_t,\n \"sam2_s.pt\": build_sam2_s,\n \"sam2_b.pt\": build_sam2_b,\n \"sam2_l.pt\": build_sam2_l,\n \"sam2.1_t.pt\": build_sam2_t,\n \"sam2.1_s.pt\": build_sam2_s,\n \"sam2.1_b.pt\": build_sam2_b,\n \"sam2.1_l.pt\": build_sam2_l,\n}", "chunk_type": "variable", "name": "sam_model_map", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 311, "end_line": 324, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_sam_model_map_98550379" }, { "content": "def build_sam(ckpt=\"sam_b.pt\"):\n \"\"\"\n Build and return a Segment Anything Model (SAM) based on the provided checkpoint.\n\n Args:\n ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.\n\n Returns:\n (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.\n\n Raises:\n FileNotFoundError: If the provided checkpoint is not a supported SAM model.\n\n Examples:\n >>> sam_model = build_sam(\"sam_b.pt\")\n >>> sam_model = build_sam(\"path/to/custom_checkpoint.pt\")\n\n Notes:\n Supported pre-defined models include:\n - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'\n - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'\n \"\"\"\n model_builder = None\n ckpt = str(ckpt) # to allow Path ckpt types\n for k in sam_model_map.keys():\n if ckpt.endswith(k):\n model_builder = sam_model_map.get(k)\n\n if not model_builder:\n raise FileNotFoundError(f\"{ckpt} is not a supported SAM model. 
Available models are: \\n {sam_model_map.keys()}\")\n\n return model_builder(ckpt)", "chunk_type": "function", "name": "build_sam", "file_path": "ultralytics\\ultralytics\\models\\sam\\build.py", "start_line": 327, "end_line": 358, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Build and return a Segment Anything Model (SAM) based on the provided checkpoint.\n\nArgs:\n ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.\n\nReturns:\n (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.\n\nRaises:\n FileNotFoundError: If the provided checkpoint is not a supported SAM model.\n\nExamples:\n >>> sam_model = build_sam(\"sam_b.pt\")\n >>> sam_model = build_sam(\"path/to/custom_checkpoint.pt\")\n\nNotes:\n Supported pre-defined models include:\n - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'\n - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'", "parameters": [ "ckpt" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "functools.partial", "torch", "ultralytics.utils.downloads.attempt_download_asset", "modules.decoders.MaskDecoder", "modules.encoders.FpnNeck", "modules.encoders.Hiera", "modules.encoders.ImageEncoder", "modules.encoders.ImageEncoderViT", "modules.encoders.MemoryEncoder", "modules.encoders.PromptEncoder", "modules.memory_attention.MemoryAttention", "modules.memory_attention.MemoryAttentionLayer", "modules.sam.SAM2Model", "modules.sam.SAMModel", "modules.tiny_encoder.TinyViT", "modules.transformer.TwoWayTransformer" ], "chunk_id": "function_build_sam_898d7f08" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\sam\\model.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_dce6a1cc" }, { "content": "from typing import Dict, Type", "chunk_type": "import", "name": "Dict, Type", "file_path": "ultralytics\\ultralytics\\models\\sam\\model.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Dict, Type_280e76d2" }, { "content": "from ultralytics.engine.model import Model", "chunk_type": "import", "name": "Model", "file_path": "ultralytics\\ultralytics\\models\\sam\\model.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Model_df6ac9de" }, { "content": "from ultralytics.utils.torch_utils import model_info", "chunk_type": "import", "name": "model_info", "file_path": "ultralytics\\ultralytics\\models\\sam\\model.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_model_info_9bba8c5a" }, { "content": "from .predict import Predictor, SAM2Predictor", "chunk_type": "import", "name": "Predictor, SAM2Predictor", "file_path": "ultralytics\\ultralytics\\models\\sam\\model.py", "start_line": 23, "end_line": 23, "start_col": 0, "end_col": 45, 
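build_sam above dispatches purely on the filename suffix via sam_model_map, so custom paths work as long as they end with a known model name, and anything else raises FileNotFoundError before any download is attempted. A small check (assumes ultralytics is installed; "yolo11n.pt" is just an example of an unsupported name):

```python
from ultralytics.models.sam.build import build_sam, sam_model_map

print(list(sam_model_map))  # filenames build_sam() recognises, matched via str.endswith()

# Matching is by suffix, so build_sam("sam_b.pt") and build_sam("weights/custom-sam_b.pt")
# would both resolve to build_sam_vit_b (and would download/load the checkpoint).
# Unrecognised names fail fast without downloading anything:
try:
    build_sam("yolo11n.pt")
except FileNotFoundError as e:
    print(type(e).__name__)  # FileNotFoundError
```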
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Predictor, SAM2Predictor_dcfb9e46" }, { "content": "class SAM(Model):\n \"\"\"\n SAM (Segment Anything Model) interface class for real-time image segmentation tasks.\n\n This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for\n promptable segmentation with versatility in image analysis. It supports various prompts such as bounding\n boxes, points, or labels, and features zero-shot performance capabilities.\n\n Attributes:\n model (torch.nn.Module): The loaded SAM model.\n is_sam2 (bool): Indicates whether the model is SAM2 variant.\n task (str): The task type, set to \"segment\" for SAM models.\n\n Methods:\n predict: Perform segmentation prediction on the given image or video source.\n info: Log information about the SAM model.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam.predict(\"image.jpg\", points=[[500, 375]])\n >>> for r in results:\n >>> print(f\"Detected {len(r.masks)} masks\")\n \"\"\"\n\n def __init__(self, model: str = \"sam_b.pt\") -> None:\n \"\"\"\n Initialize the SAM (Segment Anything Model) instance.\n\n Args:\n model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.\n\n Raises:\n NotImplementedError: If the model file extension is not .pt or .pth.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> print(sam.is_sam2)\n \"\"\"\n if model and Path(model).suffix not in {\".pt\", \".pth\"}:\n raise NotImplementedError(\"SAM prediction requires pre-trained *.pt or *.pth model.\")\n self.is_sam2 = \"sam2\" in Path(model).stem\n super().__init__(model=model, task=\"segment\")\n\n def _load(self, weights: str, task=None):\n \"\"\"\n Load the specified weights into the SAM model.\n\n Args:\n weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.\n task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> sam._load(\"path/to/custom_weights.pt\")\n \"\"\"\n from .build import build_sam # slow import\n\n self.model = build_sam(weights)\n\n def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):\n \"\"\"\n Perform segmentation prediction on the given image or video source.\n\n Args:\n source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or\n a np.ndarray object.\n stream (bool): If True, enables real-time streaming.\n bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.\n points (List[List[float]] | None): List of points for prompted segmentation.\n labels (List[int] | None): List of labels for prompted segmentation.\n **kwargs (Any): Additional keyword arguments for prediction.\n\n Returns:\n (list): The model predictions.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam.predict(\"image.jpg\", points=[[500, 375]])\n >>> for r in results:\n ... 
print(f\"Detected {len(r.masks)} masks\")\n \"\"\"\n overrides = dict(conf=0.25, task=\"segment\", mode=\"predict\", imgsz=1024)\n kwargs = {**overrides, **kwargs}\n prompts = dict(bboxes=bboxes, points=points, labels=labels)\n return super().predict(source, stream, prompts=prompts, **kwargs)\n\n def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):\n \"\"\"\n Perform segmentation prediction on the given image or video source.\n\n This method is an alias for the 'predict' method, providing a convenient way to call the SAM model\n for segmentation tasks.\n\n Args:\n source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image\n object, or a np.ndarray object.\n stream (bool): If True, enables real-time streaming.\n bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.\n points (List[List[float]] | None): List of points for prompted segmentation.\n labels (List[int] | None): List of labels for prompted segmentation.\n **kwargs (Any): Additional keyword arguments to be passed to the predict method.\n\n Returns:\n (list): The model predictions, typically containing segmentation masks and other relevant information.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam(\"image.jpg\", points=[[500, 375]])\n >>> print(f\"Detected {len(results[0].masks)} masks\")\n \"\"\"\n return self.predict(source, stream, bboxes, points, labels, **kwargs)\n\n def info(self, detailed: bool = False, verbose: bool = True):\n \"\"\"\n Log information about the SAM model.\n\n Args:\n detailed (bool): If True, displays detailed information about the model layers and operations.\n verbose (bool): If True, prints the information to the console.\n\n Returns:\n (tuple): A tuple containing the model's information (string representations of the model).\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> info = sam.info()\n >>> print(info[0]) # Print summary information\n \"\"\"\n return model_info(self.model, detailed=detailed, verbose=verbose)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Type[Predictor]]]:\n \"\"\"\n Provide a mapping from the 'segment' task to its corresponding 'Predictor'.\n\n Returns:\n (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding\n Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> task_map = sam.task_map\n >>> print(task_map)\n {'segment': {'predictor': }}\n \"\"\"\n return {\"segment\": {\"predictor\": SAM2Predictor if self.is_sam2 else Predictor}}", "chunk_type": "class", "name": "SAM", "file_path": "ultralytics\\ultralytics\\models\\sam\\model.py", "start_line": 26, "end_line": 171, "start_col": 0, "end_col": 87, "parent_name": null, "docstring": "SAM (Segment Anything Model) interface class for real-time image segmentation tasks.\n\nThis class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for\npromptable segmentation with versatility in image analysis. 
It supports various prompts such as bounding\nboxes, points, or labels, and features zero-shot performance capabilities.\n\nAttributes:\n model (torch.nn.Module): The loaded SAM model.\n is_sam2 (bool): Indicates whether the model is SAM2 variant.\n task (str): The task type, set to \"segment\" for SAM models.\n\nMethods:\n predict: Perform segmentation prediction on the given image or video source.\n info: Log information about the SAM model.\n\nExamples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam.predict(\"image.jpg\", points=[[500, 375]])\n >>> for r in results:\n >>> print(f\"Detected {len(r.masks)} masks\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Dict", "typing.Type", "ultralytics.engine.model.Model", "ultralytics.utils.torch_utils.model_info", "predict.Predictor", "predict.SAM2Predictor", "build.build_sam", "Model" ], "chunk_id": "class_SAM_9775bc2e" }, { "content": "from collections import OrderedDict", "chunk_type": "import", "name": "OrderedDict", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OrderedDict_ac6d9e8b" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_0942c671" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_ffeb0bc7" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_4464cd3c" }, { "content": "from ultralytics.data.augment import LetterBox", "chunk_type": "import", "name": "LetterBox", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LetterBox_76b72a38" }, { "content": "from ultralytics.engine.predictor import BasePredictor", "chunk_type": "import", "name": "BasePredictor", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BasePredictor_5265f07c" }, { "content": "from ultralytics.engine.results import Results", 
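Editor's note: a hedged end-to-end sketch of the SAM interface documented above, assuming the package's top-level SAM export; the image path and all coordinates are illustrative.

from ultralytics import SAM

model = SAM("sam2.1_b.pt")   # is_sam2 is inferred from the "sam2" filename stem
print(model.is_sam2)         # True

# Point prompt: one foreground point (label 1) at pixel (500, 375); "image.jpg" is a placeholder.
results = model.predict("image.jpg", points=[[500, 375]], labels=[1])

# Box prompt via __call__, which is an alias for predict().
results = model("image.jpg", bboxes=[[100, 100, 200, 200]])
for r in results:
    print(f"Detected {len(r.masks)} masks")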
"chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_90d1e37f" }, { "content": "from ultralytics.utils import DEFAULT_CFG, ops", "chunk_type": "import", "name": "DEFAULT_CFG, ops", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, ops_37bd5d37" }, { "content": "from ultralytics.utils.torch_utils import select_device, smart_inference_mode", "chunk_type": "import", "name": "select_device, smart_inference_mode", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_select_device, smart_inference_mode_c00f2f8b" }, { "content": "from .amg import (\n batch_iterator,\n batched_mask_to_box,\n build_all_layer_point_grids,\n calculate_stability_score,\n generate_crop_boxes,\n is_box_near_crop_edge,\n remove_small_regions,\n uncrop_boxes_xyxy,\n uncrop_masks,\n)", "chunk_type": "import", "name": "batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 23, "end_line": 33, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks_682436ba" }, { "content": "class Predictor(BasePredictor):\n \"\"\"\n Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.\n\n This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image\n segmentation tasks. 
It supports various input prompts like points, bounding boxes, and masks for\n fine-grained control over segmentation results.\n\n Attributes:\n args (SimpleNamespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded SAM model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n im (torch.Tensor): The preprocessed input image.\n features (torch.Tensor): Extracted image features.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).\n segment_all (bool): Flag to indicate if full image segmentation should be performed.\n mean (torch.Tensor): Mean values for image normalization.\n std (torch.Tensor): Standard deviation values for image normalization.\n\n Methods:\n preprocess: Prepare input images for model inference.\n pre_transform: Perform initial transformations on the input image.\n inference: Perform segmentation inference based on input prompts.\n prompt_inference: Internal function for prompt-based segmentation inference.\n generate: Generate segmentation masks for an entire image.\n setup_model: Initialize the SAM model for inference.\n get_model: Build and return a SAM model.\n postprocess: Post-process model outputs to generate final results.\n setup_source: Set up the data source for inference.\n set_image: Set and preprocess a single image for inference.\n get_im_features: Extract image features using the SAM image encoder.\n set_prompts: Set prompts for subsequent inference.\n reset_image: Reset the current image and its features.\n remove_small_regions: Remove small disconnected regions and holes from masks.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model_path=\"sam_model.pt\")\n >>> predictor.set_image(\"image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the Predictor with configuration, overrides, and callbacks.\n\n Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or\n callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True\n for optimal results.\n\n Args:\n cfg (dict): Configuration dictionary containing default settings.\n overrides (dict | None): Dictionary of values to override default configuration.\n _callbacks (dict | None): Dictionary of callback functions to customize behavior.\n\n Examples:\n >>> predictor_example = Predictor(cfg=DEFAULT_CFG)\n >>> predictor_example_with_imgsz = Predictor(overrides={\"imgsz\": 640})\n >>> predictor_example_with_callback = Predictor(_callbacks={\"on_predict_start\": custom_callback})\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides.update(dict(task=\"segment\", mode=\"predict\", batch=1))\n super().__init__(cfg, overrides, _callbacks)\n self.args.retina_masks = True\n self.im = None\n self.features = None\n self.prompts = {}\n self.segment_all = False\n\n def preprocess(self, im):\n \"\"\"\n Preprocess the input image for model inference.\n\n This method prepares the input image by applying transformations and normalization. 
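Editor's note: a small numeric sketch of the preprocessing math described here; the zero image is a stand-in for a letterboxed frame, and the mean/std constants are the ones assigned later in setup_model.

import numpy as np
import torch

# Placeholder frame standing in for a letterboxed 1024x1024 BGR image.
img = np.zeros((1024, 1024, 3), dtype=np.uint8)

# HWC BGR uint8 -> BCHW RGB float, as preprocess() does for non-tensor input.
x = np.stack([img])[..., ::-1].transpose(0, 3, 1, 2)
x = torch.from_numpy(np.ascontiguousarray(x)).float()

# Normalization constants set in setup_model() further below.
mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
x = (x - mean) / std
print(x.shape)  # torch.Size([1, 3, 1024, 1024])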
It supports both\n torch.Tensor and list of np.ndarray as input formats.\n\n Args:\n im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.\n\n Returns:\n (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.\n\n Examples:\n >>> predictor = Predictor()\n >>> image = torch.rand(1, 3, 640, 640)\n >>> preprocessed_image = predictor.preprocess(image)\n \"\"\"\n if self.im is not None:\n return self.im\n not_tensor = not isinstance(im, torch.Tensor)\n if not_tensor:\n im = np.stack(self.pre_transform(im))\n im = im[..., ::-1].transpose((0, 3, 1, 2))\n im = np.ascontiguousarray(im)\n im = torch.from_numpy(im)\n\n im = im.to(self.device)\n im = im.half() if self.model.fp16 else im.float()\n if not_tensor:\n im = (im - self.mean) / self.std\n return im\n\n def pre_transform(self, im):\n \"\"\"\n Perform initial transformations on the input image for preprocessing.\n\n This method applies transformations such as resizing to prepare the image for further preprocessing.\n Currently, batched inference is not supported; hence the list length should be 1.\n\n Args:\n im (List[np.ndarray]): List containing a single image in HWC numpy array format.\n\n Returns:\n (List[np.ndarray]): List containing the transformed image.\n\n Raises:\n AssertionError: If the input list contains more than one image.\n\n Examples:\n >>> predictor = Predictor()\n >>> image = np.random.rand(480, 640, 3) # Single HWC image\n >>> transformed = predictor.pre_transform([image])\n >>> print(len(transformed))\n 1\n \"\"\"\n assert len(im) == 1, \"SAM model does not currently support batched inference\"\n letterbox = LetterBox(self.args.imgsz, auto=False, center=False)\n return [letterbox(image=x) for x in im]\n\n def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):\n \"\"\"\n Perform image segmentation inference based on the given input cues, using the currently loaded image.\n\n This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt\n encoder, and mask decoder for real-time and promptable segmentation tasks.\n\n Args:\n im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).\n bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.\n labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.\n masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.\n multimask_output (bool): Flag to return multiple masks. 
Helpful for ambiguous prompts.\n *args (Any): Additional positional arguments.\n **kwargs (Any): Additional keyword arguments.\n\n Returns:\n pred_masks (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.\n pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.\n pred_logits (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model_path=\"sam_model.pt\")\n >>> predictor.set_image(\"image.jpg\")\n >>> results = predictor(bboxes=[[0, 0, 100, 100]])\n \"\"\"\n # Override prompts if any stored in self.prompts\n bboxes = self.prompts.pop(\"bboxes\", bboxes)\n points = self.prompts.pop(\"points\", points)\n masks = self.prompts.pop(\"masks\", masks)\n labels = self.prompts.pop(\"labels\", labels)\n\n if all(i is None for i in [bboxes, points, masks]):\n return self.generate(im, *args, **kwargs)\n\n return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)\n\n def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):\n \"\"\"\n Perform image segmentation inference based on input cues using SAM's specialized architecture.\n\n This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.\n It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.\n\n Args:\n im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).\n bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.\n labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.\n masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). 
For SAM, H=W=256.\n multimask_output (bool): Flag to return multiple masks for ambiguous prompts.\n\n Returns:\n pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.\n pred_scores (np.ndarray): Quality scores predicted by the model for each mask, with length C.\n\n Examples:\n >>> predictor = Predictor()\n >>> im = torch.rand(1, 3, 1024, 1024)\n >>> bboxes = [[100, 100, 200, 200]]\n >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)\n \"\"\"\n features = self.get_im_features(im) if self.features is None else self.features\n\n bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)\n points = (points, labels) if points is not None else None\n # Embed prompts\n sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)\n\n # Predict masks\n pred_masks, pred_scores = self.model.mask_decoder(\n image_embeddings=features,\n image_pe=self.model.prompt_encoder.get_dense_pe(),\n sparse_prompt_embeddings=sparse_embeddings,\n dense_prompt_embeddings=dense_embeddings,\n multimask_output=multimask_output,\n )\n\n # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )\n # `d` could be 1 or 3 depends on `multimask_output`.\n return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)\n\n def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):\n \"\"\"\n Prepare and transform the input prompts for processing based on the destination shape.\n\n Args:\n dst_shape (tuple): The target shape (height, width) for the prompts.\n bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.\n labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 
1 for foreground, 0 for background.\n masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.\n\n Returns:\n bboxes (torch.Tensor | None): Transformed bounding boxes.\n points (torch.Tensor | None): Transformed points.\n labels (torch.Tensor | None): Transformed labels.\n masks (torch.Tensor | None): Transformed masks.\n\n Raises:\n AssertionError: If the number of points don't match the number of labels, in case labels were passed.\n \"\"\"\n src_shape = self.batch[1][0].shape[:2]\n r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])\n # Transform input prompts\n if points is not None:\n points = torch.as_tensor(points, dtype=torch.float32, device=self.device)\n points = points[None] if points.ndim == 1 else points\n # Assuming labels are all positive if users don't pass labels.\n if labels is None:\n labels = np.ones(points.shape[:-1])\n labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)\n assert points.shape[-2] == labels.shape[-1], (\n f\"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}.\"\n )\n points *= r\n if points.ndim == 2:\n # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)\n points, labels = points[:, None, :], labels[:, None]\n if bboxes is not None:\n bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)\n bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes\n bboxes *= r\n if masks is not None:\n masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)\n return bboxes, points, labels, masks\n\n def generate(\n self,\n im,\n crop_n_layers=0,\n crop_overlap_ratio=512 / 1500,\n crop_downscale_factor=1,\n point_grids=None,\n points_stride=32,\n points_batch_size=64,\n conf_thres=0.88,\n stability_score_thresh=0.95,\n stability_score_offset=0.95,\n crop_nms_thresh=0.7,\n ):\n \"\"\"\n Perform image segmentation using the Segment Anything Model (SAM).\n\n This method segments an entire image into constituent parts by leveraging SAM's advanced architecture\n and real-time performance capabilities. 
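Editor's note: a hedged sketch of the segment-everything path; per inference() above, calling the model with no box, point, or mask prompts falls through to generate(), which samples a point grid (optionally over crops) and NMS-merges the resulting masks. The checkpoint name and image path are placeholders.

from ultralytics import SAM   # assumes the top-level SAM export

model = SAM("sam_b.pt")
results = model("image.jpg")              # no prompts -> Predictor.generate() is used
r = results[0]
print(f"Found {len(r.masks)} segments")   # one mask (and one pseudo-box score) per segment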
It can optionally work on image crops for finer segmentation.\n\n Args:\n im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).\n crop_n_layers (int): Number of layers for additional mask predictions on image crops.\n crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.\n crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.\n point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].\n points_stride (int): Number of points to sample along each side of the image.\n points_batch_size (int): Batch size for the number of points processed simultaneously.\n conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.\n stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.\n stability_score_offset (float): Offset value for calculating stability score.\n crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.\n\n Returns:\n pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).\n pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).\n pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).\n\n Examples:\n >>> predictor = Predictor()\n >>> im = torch.rand(1, 3, 1024, 1024) # Example input image\n >>> masks, scores, boxes = predictor.generate(im)\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n self.segment_all = True\n ih, iw = im.shape[2:]\n crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)\n if point_grids is None:\n point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)\n pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []\n for crop_region, layer_idx in zip(crop_regions, layer_idxs):\n x1, y1, x2, y2 = crop_region\n w, h = x2 - x1, y2 - y1\n area = torch.tensor(w * h, device=im.device)\n points_scale = np.array([[w, h]]) # w, h\n # Crop image and interpolate to input size\n crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode=\"bilinear\", align_corners=False)\n # (num_points, 2)\n points_for_image = point_grids[layer_idx] * points_scale\n crop_masks, crop_scores, crop_bboxes = [], [], []\n for (points,) in batch_iterator(points_batch_size, points_for_image):\n pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)\n # Interpolate predicted masks to input size\n pred_mask = F.interpolate(pred_mask[None], (h, w), mode=\"bilinear\", align_corners=False)[0]\n idx = pred_score > conf_thres\n pred_mask, pred_score = pred_mask[idx], pred_score[idx]\n\n stability_score = calculate_stability_score(\n pred_mask, self.model.mask_threshold, stability_score_offset\n )\n idx = stability_score > stability_score_thresh\n pred_mask, pred_score = pred_mask[idx], pred_score[idx]\n # Bool type is much more memory-efficient.\n pred_mask = pred_mask > self.model.mask_threshold\n # (N, 4)\n pred_bbox = batched_mask_to_box(pred_mask).float()\n keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])\n if not torch.all(keep_mask):\n pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]\n\n crop_masks.append(pred_mask)\n crop_bboxes.append(pred_bbox)\n crop_scores.append(pred_score)\n\n # Do nms within this crop\n crop_masks = torch.cat(crop_masks)\n crop_bboxes = torch.cat(crop_bboxes)\n 
crop_scores = torch.cat(crop_scores)\n keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS\n crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)\n crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)\n crop_scores = crop_scores[keep]\n\n pred_masks.append(crop_masks)\n pred_bboxes.append(crop_bboxes)\n pred_scores.append(crop_scores)\n region_areas.append(area.expand(len(crop_masks)))\n\n pred_masks = torch.cat(pred_masks)\n pred_bboxes = torch.cat(pred_bboxes)\n pred_scores = torch.cat(pred_scores)\n region_areas = torch.cat(region_areas)\n\n # Remove duplicate masks between crops\n if len(crop_regions) > 1:\n scores = 1 / region_areas\n keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)\n pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]\n\n return pred_masks, pred_scores, pred_bboxes\n\n def setup_model(self, model=None, verbose=True):\n \"\"\"\n Initialize the Segment Anything Model (SAM) for inference.\n\n This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary\n parameters for image normalization and other Ultralytics compatibility settings.\n\n Args:\n model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.\n verbose (bool): If True, prints selected device information.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model=sam_model, verbose=True)\n \"\"\"\n device = select_device(self.args.device, verbose=verbose)\n if model is None:\n model = self.get_model()\n model.eval()\n self.model = model.to(device)\n self.device = device\n self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)\n self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)\n\n # Ultralytics compatibility settings\n self.model.pt = False\n self.model.triton = False\n self.model.stride = 32\n self.model.fp16 = False\n self.done_warmup = True\n\n def get_model(self):\n \"\"\"Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks.\"\"\"\n from .build import build_sam # slow import\n\n return build_sam(self.args.model)\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Post-process SAM's inference outputs to generate object detection masks and bounding boxes.\n\n This method scales masks and boxes to the original image size and applies a threshold to the mask\n predictions. 
It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.\n\n Args:\n preds (tuple): The output from SAM model inference, containing:\n - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).\n - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).\n - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.\n img (torch.Tensor): The processed input image tensor with shape (C, H, W).\n orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.\n\n Returns:\n (List[Results]): List of Results objects containing detection masks, bounding boxes, and other\n metadata for each processed image.\n\n Examples:\n >>> predictor = Predictor()\n >>> preds = predictor.inference(img)\n >>> results = predictor.postprocess(preds, img, orig_imgs)\n \"\"\"\n # (N, 1, H, W), (N, 1)\n pred_masks, pred_scores = preds[:2]\n pred_bboxes = preds[2] if self.segment_all else None\n names = dict(enumerate(str(i) for i in range(len(pred_masks))))\n\n if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n results = []\n for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):\n if len(masks) == 0:\n masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)\n else:\n masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]\n masks = masks > self.model.mask_threshold # to bool\n if pred_bboxes is not None:\n pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)\n else:\n pred_bboxes = batched_mask_to_box(masks)\n # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.\n cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)\n pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)\n results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))\n # Reset segment-all mode.\n self.segment_all = False\n return results\n\n def setup_source(self, source):\n \"\"\"\n Set up the data source for inference.\n\n This method configures the data source from which images will be fetched for inference. It supports\n various input types such as image files, directories, video files, and other compatible data sources.\n\n Args:\n source (str | Path | None): The path or identifier for the image data source. Can be a file path,\n directory path, URL, or other supported source types.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_source(\"path/to/images\")\n >>> predictor.setup_source(\"video.mp4\")\n >>> predictor.setup_source(None) # Uses default source if available\n\n Notes:\n - If source is None, the method may use a default source if configured.\n - The method adapts to different source types and prepares them for subsequent inference steps.\n - Supported source types may include local files, directories, URLs, and video streams.\n \"\"\"\n if source is not None:\n super().setup_source(source)\n\n def set_image(self, image):\n \"\"\"\n Preprocess and set a single image for inference.\n\n This method prepares the model for inference on a single image by setting up the model if not already\n initialized, configuring the data source, and preprocessing the image for feature extraction. 
It\n ensures that only one image is set at a time and extracts image features for subsequent use.\n\n Args:\n image (str | np.ndarray): Path to the image file as a string, or a numpy array representing\n an image read by cv2.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> predictor.set_image(cv2.imread(\"path/to/image.jpg\"))\n\n Raises:\n AssertionError: If more than one image is attempted to be set.\n\n Notes:\n - This method should be called before performing inference on a new image.\n - The extracted features are stored in the `self.features` attribute for later use.\n \"\"\"\n if self.model is None:\n self.setup_model(model=None)\n self.setup_source(image)\n assert len(self.dataset) == 1, \"`set_image` only supports setting one image!\"\n for batch in self.dataset:\n im = self.preprocess(batch[1])\n self.features = self.get_im_features(im)\n break\n\n def get_im_features(self, im):\n \"\"\"Extract image features using the SAM model's image encoder for subsequent mask prediction.\"\"\"\n assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (\n f\"SAM models only support square image size, but got {self.imgsz}.\"\n )\n self.model.set_imgsz(self.imgsz)\n return self.model.image_encoder(im)\n\n def set_prompts(self, prompts):\n \"\"\"Set prompts for subsequent inference operations.\"\"\"\n self.prompts = prompts\n\n def reset_image(self):\n \"\"\"Reset the current image and its features, clearing them for subsequent inference.\"\"\"\n self.im = None\n self.features = None\n\n @staticmethod\n def remove_small_regions(masks, min_area=0, nms_thresh=0.7):\n \"\"\"\n Remove small disconnected regions and holes from segmentation masks.\n\n This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).\n It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum\n Suppression (NMS) to eliminate any newly created duplicate boxes.\n\n Args:\n masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of\n masks, H is height, and W is width.\n min_area (int): Minimum area threshold for removing disconnected regions and holes. 
Regions smaller than\n this will be removed.\n nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.\n\n Returns:\n new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).\n keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.\n\n Examples:\n >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks\n >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)\n >>> print(f\"Original masks: {masks.shape}, Processed masks: {new_masks.shape}\")\n >>> print(f\"Indices of kept masks: {keep}\")\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n if len(masks) == 0:\n return masks\n\n # Filter small disconnected regions and holes\n new_masks = []\n scores = []\n for mask in masks:\n mask = mask.cpu().numpy().astype(np.uint8)\n mask, changed = remove_small_regions(mask, min_area, mode=\"holes\")\n unchanged = not changed\n mask, changed = remove_small_regions(mask, min_area, mode=\"islands\")\n unchanged = unchanged and not changed\n\n new_masks.append(torch.as_tensor(mask).unsqueeze(0))\n # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing\n scores.append(float(unchanged))\n\n # Recalculate boxes and remove any new duplicates\n new_masks = torch.cat(new_masks, dim=0)\n boxes = batched_mask_to_box(new_masks)\n keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)\n\n return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep", "chunk_type": "class", "name": "Predictor", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 36, "end_line": 621, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": "Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.\n\nThis class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image\nsegmentation tasks. 
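Editor's note: a hedged sketch of the prompted Predictor workflow from the class Examples above, adjusted to setup_model's actual model=None signature (the model is built lazily from the "model" override); checkpoint and image paths are placeholders.

from ultralytics.models.sam.predict import Predictor

# Placeholders: "sam_b.pt" checkpoint name and "image.jpg" path.
predictor = Predictor(overrides={"model": "sam_b.pt", "imgsz": 1024})
predictor.set_image("image.jpg")      # builds the model lazily and caches image features

results = predictor(bboxes=[[100, 100, 200, 200]])     # box prompt on the cached image
results = predictor(points=[[250, 250]], labels=[1])   # point prompt reuses the cached features
predictor.reset_image()               # clear cached features before switching images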
It supports various input prompts like points, bounding boxes, and masks for\nfine-grained control over segmentation results.\n\nAttributes:\n args (SimpleNamespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded SAM model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n im (torch.Tensor): The preprocessed input image.\n features (torch.Tensor): Extracted image features.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).\n segment_all (bool): Flag to indicate if full image segmentation should be performed.\n mean (torch.Tensor): Mean values for image normalization.\n std (torch.Tensor): Standard deviation values for image normalization.\n\nMethods:\n preprocess: Prepare input images for model inference.\n pre_transform: Perform initial transformations on the input image.\n inference: Perform segmentation inference based on input prompts.\n prompt_inference: Internal function for prompt-based segmentation inference.\n generate: Generate segmentation masks for an entire image.\n setup_model: Initialize the SAM model for inference.\n get_model: Build and return a SAM model.\n postprocess: Post-process model outputs to generate final results.\n setup_source: Set up the data source for inference.\n set_image: Set and preprocess a single image for inference.\n get_im_features: Extract image features using the SAM image encoder.\n set_prompts: Set prompts for subsequent inference.\n reset_image: Reset the current image and its features.\n remove_small_regions: Remove small disconnected regions and holes from masks.\n\nExamples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model_path=\"sam_model.pt\")\n >>> predictor.set_image(\"image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.OrderedDict", "numpy", "torch", "torch.nn.functional", "ultralytics.data.augment.LetterBox", "ultralytics.engine.predictor.BasePredictor", "ultralytics.engine.results.Results", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.ops", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode", "amg.batch_iterator", "amg.batched_mask_to_box", "amg.build_all_layer_point_grids", "amg.calculate_stability_score", "amg.generate_crop_boxes", "amg.is_box_near_crop_edge", "amg.remove_small_regions", "amg.uncrop_boxes_xyxy", "amg.uncrop_masks", "torchvision", "build.build_sam", "torchvision", "build.build_sam", "BasePredictor" ], "chunk_id": "class_Predictor_3cfa213f" }, { "content": "class SAM2Predictor(Predictor):\n \"\"\"\n SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.\n\n This class extends the base Predictor class to implement SAM2-specific functionality for image\n segmentation tasks. 
It provides methods for model initialization, feature extraction, and\n prompt-based inference.\n\n Attributes:\n _bb_feat_sizes (List[tuple]): Feature sizes for different backbone levels.\n model (torch.nn.Module): The loaded SAM2 model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n features (dict): Cached image features for efficient inference.\n segment_all (bool): Flag to indicate if all segments should be predicted.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts for inference.\n\n Methods:\n get_model: Retrieve and initialize the SAM2 model.\n prompt_inference: Perform image segmentation inference based on various prompts.\n set_image: Preprocess and set a single image for inference.\n get_im_features: Extract and process image features using SAM2's image encoder.\n\n Examples:\n >>> predictor = SAM2Predictor(cfg)\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> result = predictor(bboxes=bboxes)[0]\n >>> print(f\"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}\")\n \"\"\"\n\n _bb_feat_sizes = [\n (256, 256),\n (128, 128),\n (64, 64),\n ]\n\n def get_model(self):\n \"\"\"Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks.\"\"\"\n from .build import build_sam # slow import\n\n return build_sam(self.args.model)\n\n def prompt_inference(\n self,\n im,\n bboxes=None,\n points=None,\n labels=None,\n masks=None,\n multimask_output=False,\n img_idx=-1,\n ):\n \"\"\"\n Perform image segmentation inference based on various prompts using SAM2 architecture.\n\n This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images\n based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and\n multi-object prediction scenarios.\n\n Args:\n im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).\n bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.\n labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 
1 = foreground, 0 = background.\n masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).\n multimask_output (bool): Flag to return multiple masks for ambiguous prompts.\n img_idx (int): Index of the image in the batch to process.\n\n Returns:\n pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.\n pred_scores (np.ndarray): Quality scores for each mask, with length C.\n\n Examples:\n >>> predictor = SAM2Predictor(cfg)\n >>> image = torch.rand(1, 3, 640, 640)\n >>> bboxes = [[100, 100, 200, 200]]\n >>> result = predictor(image, bboxes=bboxes)[0]\n >>> print(f\"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}\")\n\n Notes:\n - The method supports batched inference for multiple objects when points or bboxes are provided.\n - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.\n - When both bboxes and points are provided, they are merged into a single 'points' input for the model.\n \"\"\"\n features = self.get_im_features(im) if self.features is None else self.features\n\n points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)\n points = (points, labels) if points is not None else None\n\n sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(\n points=points,\n boxes=None,\n masks=masks,\n )\n # Predict masks\n batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction\n high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features[\"high_res_feats\"]]\n pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(\n image_embeddings=features[\"image_embed\"][img_idx].unsqueeze(0),\n image_pe=self.model.sam_prompt_encoder.get_dense_pe(),\n sparse_prompt_embeddings=sparse_embeddings,\n dense_prompt_embeddings=dense_embeddings,\n multimask_output=multimask_output,\n repeat_image=batched_mode,\n high_res_features=high_res_features,\n )\n # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )\n # `d` could be 1 or 3 depends on `multimask_output`.\n return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)\n\n def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):\n \"\"\"\n Prepare and transform the input prompts for processing based on the destination shape.\n\n Args:\n dst_shape (tuple): The target shape (height, width) for the prompts.\n bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.\n labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 
1 for foreground, 0 for background.\n masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.\n\n Returns:\n points (torch.Tensor | None): Transformed points.\n labels (torch.Tensor | None): Transformed labels.\n masks (torch.Tensor | None): Transformed masks.\n\n Raises:\n AssertionError: If the number of points don't match the number of labels, in case labels were passed.\n \"\"\"\n bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)\n if bboxes is not None:\n bboxes = bboxes.view(-1, 2, 2)\n bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)\n # NOTE: merge \"boxes\" and \"points\" into a single \"points\" input\n # (where boxes are added at the beginning) to model.sam_prompt_encoder\n if points is not None:\n points = torch.cat([bboxes, points], dim=1)\n labels = torch.cat([bbox_labels, labels], dim=1)\n else:\n points, labels = bboxes, bbox_labels\n return points, labels, masks\n\n def set_image(self, image):\n \"\"\"\n Preprocess and set a single image for inference using the SAM2 model.\n\n This method initializes the model if not already done, configures the data source to the specified image,\n and preprocesses the image for feature extraction. It supports setting only one image at a time.\n\n Args:\n image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.\n\n Examples:\n >>> predictor = SAM2Predictor()\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> predictor.set_image(np.array([...])) # Using a numpy array\n\n Raises:\n AssertionError: If more than one image is attempted to be set.\n\n Notes:\n - This method must be called before performing any inference on a new image.\n - The method caches the extracted features for efficient subsequent inferences on the same image.\n - Only one image can be set at a time. 
To process multiple images, call this method for each new image.\n \"\"\"\n if self.model is None:\n self.setup_model(model=None)\n self.setup_source(image)\n assert len(self.dataset) == 1, \"`set_image` only supports setting one image!\"\n for batch in self.dataset:\n im = self.preprocess(batch[1])\n self.features = self.get_im_features(im)\n break\n\n def get_im_features(self, im):\n \"\"\"Extract image features from the SAM image encoder for subsequent processing.\"\"\"\n assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (\n f\"SAM 2 models only support square image size, but got {self.imgsz}.\"\n )\n self.model.set_imgsz(self.imgsz)\n self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]\n\n backbone_out = self.model.forward_image(im)\n _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)\n if self.model.directly_add_no_mem_embed:\n vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed\n feats = [\n feat.permute(1, 2, 0).view(1, -1, *feat_size)\n for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])\n ][::-1]\n return {\"image_embed\": feats[-1], \"high_res_feats\": feats[:-1]}", "chunk_type": "class", "name": "SAM2Predictor", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 624, "end_line": 814, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": "SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.\n\nThis class extends the base Predictor class to implement SAM2-specific functionality for image\nsegmentation tasks. It provides methods for model initialization, feature extraction, and\nprompt-based inference.\n\nAttributes:\n _bb_feat_sizes (List[tuple]): Feature sizes for different backbone levels.\n model (torch.nn.Module): The loaded SAM2 model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n features (dict): Cached image features for efficient inference.\n segment_all (bool): Flag to indicate if all segments should be predicted.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts for inference.\n\nMethods:\n get_model: Retrieve and initialize the SAM2 model.\n prompt_inference: Perform image segmentation inference based on various prompts.\n set_image: Preprocess and set a single image for inference.\n get_im_features: Extract and process image features using SAM2's image encoder.\n\nExamples:\n >>> predictor = SAM2Predictor(cfg)\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> result = predictor(bboxes=bboxes)[0]\n >>> print(f\"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.OrderedDict", "numpy", "torch", "torch.nn.functional", "ultralytics.data.augment.LetterBox", "ultralytics.engine.predictor.BasePredictor", "ultralytics.engine.results.Results", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.ops", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode", "amg.batch_iterator", "amg.batched_mask_to_box", "amg.build_all_layer_point_grids", "amg.calculate_stability_score", "amg.generate_crop_boxes", "amg.is_box_near_crop_edge", "amg.remove_small_regions", "amg.uncrop_boxes_xyxy", "amg.uncrop_masks", "torchvision", "build.build_sam", "torchvision", "build.build_sam", "Predictor" ], 
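Editor's note: a hedged SAM2Predictor sketch mirroring the class Examples above; the checkpoint and image path are placeholders. As _prepare_prompts shows, SAM2 feeds each box to the prompt encoder as a pair of corner points labelled 2 and 3.

from ultralytics.models.sam.predict import SAM2Predictor

# Placeholders: "sam2_b.pt" checkpoint and image path.
predictor = SAM2Predictor(overrides={"model": "sam2_b.pt", "imgsz": 1024})
predictor.set_image("path/to/image.jpg")
result = predictor(bboxes=[[100, 100, 200, 200]])[0]   # box prompt, encoded as corner points
print(f"Predicted {len(result.masks)} masks with mean score {result.boxes.conf.mean():.2f}")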
"chunk_id": "class_SAM2Predictor_0545c371" }, { "content": "class SAM2VideoPredictor(SAM2Predictor):\n \"\"\"\n SAM2VideoPredictor to handle user interactions with videos and manage inference states.\n\n This class extends the functionality of SAM2Predictor to support video processing and maintains\n the state of inference operations. It includes configurations for managing non-overlapping masks,\n clearing memory for non-conditional inputs, and setting up callbacks for prediction events.\n\n Attributes:\n inference_state (dict): A dictionary to store the current state of inference operations.\n non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.\n clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.\n clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.\n callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.\n\n Methods:\n get_model: Retrieve and configure the model with binarization enabled.\n inference: Perform image segmentation inference based on the given input cues.\n postprocess: Post-process the predictions to apply non-overlapping constraints if required.\n add_new_prompts: Add new points or masks to a specific frame for a given object ID.\n propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.\n init_state: Initialize an inference state for the predictor.\n get_im_features: Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.\n\n Examples:\n >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)\n >>> predictor.set_image(\"path/to/video_frame.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)\n\n Note:\n The `fill_hole_area` attribute is defined but not used in the current implementation.\n \"\"\"\n\n # fill_hole_area = 8 # not used\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the predictor with configuration and optional overrides.\n\n This constructor initializes the SAM2VideoPredictor with a given configuration, applies any\n specified overrides, and sets up the inference state along with certain flags\n that control the behavior of the predictor.\n\n Args:\n cfg (dict): Configuration dictionary containing default settings.\n overrides (dict | None): Dictionary of values to override default configuration.\n _callbacks (dict | None): Dictionary of callback functions to customize behavior.\n\n Examples:\n >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)\n >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={\"imgsz\": 640})\n >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={\"on_predict_start\": custom_callback})\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.inference_state = {}\n self.non_overlap_masks = True\n self.clear_non_cond_mem_around_input = False\n self.clear_non_cond_mem_for_multi_obj = False\n self.callbacks[\"on_predict_start\"].append(self.init_state)\n\n def get_model(self):\n \"\"\"\n Retrieve and configure the model with binarization enabled.\n\n Note:\n This method overrides the base class implementation to set the binarize flag to True.\n \"\"\"\n model = super().get_model()\n model.set_binarize(True)\n return model\n\n def inference(self, im, bboxes=None, points=None, labels=None, masks=None):\n \"\"\"\n Perform image segmentation inference based on the 
given input cues, using the currently loaded image. This\n method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and\n mask decoder for real-time and promptable segmentation tasks.\n\n Args:\n im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).\n bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.\n points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.\n labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.\n masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.\n\n Returns:\n pred_masks (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.\n pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.\n \"\"\"\n # Override prompts if any stored in self.prompts\n bboxes = self.prompts.pop(\"bboxes\", bboxes)\n points = self.prompts.pop(\"points\", points)\n masks = self.prompts.pop(\"masks\", masks)\n\n frame = self.dataset.frame\n self.inference_state[\"im\"] = im\n output_dict = self.inference_state[\"output_dict\"]\n if len(output_dict[\"cond_frame_outputs\"]) == 0: # initialize prompts\n points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)\n if points is not None:\n for i in range(len(points)):\n self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)\n elif masks is not None:\n for i in range(len(masks)):\n self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)\n self.propagate_in_video_preflight()\n\n consolidated_frame_inds = self.inference_state[\"consolidated_frame_inds\"]\n batch_size = len(self.inference_state[\"obj_idx_to_id\"])\n if len(output_dict[\"cond_frame_outputs\"]) == 0:\n raise RuntimeError(\"No points are provided; please add points first\")\n\n if frame in consolidated_frame_inds[\"cond_frame_outputs\"]:\n storage_key = \"cond_frame_outputs\"\n current_out = output_dict[storage_key][frame]\n if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):\n # clear non-conditioning memory of the surrounding frames\n self._clear_non_cond_mem_around_input(frame)\n elif frame in consolidated_frame_inds[\"non_cond_frame_outputs\"]:\n storage_key = \"non_cond_frame_outputs\"\n current_out = output_dict[storage_key][frame]\n else:\n storage_key = \"non_cond_frame_outputs\"\n current_out = self._run_single_frame_inference(\n output_dict=output_dict,\n frame_idx=frame,\n batch_size=batch_size,\n is_init_cond_frame=False,\n point_inputs=None,\n mask_inputs=None,\n reverse=False,\n run_mem_encoder=True,\n )\n output_dict[storage_key][frame] = current_out\n # Create slices of per-object outputs for subsequent interaction with each\n # individual object after tracking.\n self._add_output_per_object(frame, current_out, storage_key)\n self.inference_state[\"frames_already_tracked\"].append(frame)\n pred_masks = current_out[\"pred_masks\"].flatten(0, 1)\n pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks\n\n return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Post-process the predictions to apply non-overlapping constraints if required.\n\n 
This method extends the post-processing functionality by applying non-overlapping constraints\n to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that\n the masks do not overlap, which can be useful for certain applications.\n\n Args:\n preds (tuple): The predictions from the model.\n img (torch.Tensor): The processed image tensor.\n orig_imgs (List[np.ndarray]): The original images before processing.\n\n Returns:\n (list): The post-processed predictions.\n\n Note:\n If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.\n \"\"\"\n results = super().postprocess(preds, img, orig_imgs)\n if self.non_overlap_masks:\n for result in results:\n if result.masks is None or len(result.masks) == 0:\n continue\n result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]\n return results\n\n @smart_inference_mode()\n def add_new_prompts(\n self,\n obj_id,\n points=None,\n labels=None,\n masks=None,\n frame_idx=0,\n ):\n \"\"\"\n Add new points or masks to a specific frame for a given object ID.\n\n This method updates the inference state with new prompts (points or masks) for a specified\n object and frame index. It ensures that the prompts are either points or masks, but not both,\n and updates the internal state accordingly. It also handles the generation of new segmentations\n based on the provided prompts and the existing state.\n\n Args:\n obj_id (int): The ID of the object to which the prompts are associated.\n points (torch.Tensor, optional): The coordinates of the points of interest.\n labels (torch.Tensor, optional): The labels corresponding to the points.\n masks (torch.Tensor, optional): Binary masks for the object.\n frame_idx (int, optional): The index of the frame to which the prompts are applied.\n\n Returns:\n pred_masks (torch.Tensor): The flattened predicted masks.\n pred_scores (torch.Tensor): A tensor of ones indicating the number of objects.\n\n Raises:\n AssertionError: If both `masks` and `points` are provided, or neither is provided.\n\n Note:\n - Only one type of prompt (either points or masks) can be added per call.\n - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.\n - The method handles the consolidation of outputs and resizing of masks to the original video resolution.\n \"\"\"\n assert (masks is None) ^ (points is None), \"'masks' and 'points' prompts are not compatible with each other.\"\n obj_idx = self._obj_id_to_idx(obj_id)\n\n point_inputs = None\n pop_key = \"point_inputs_per_obj\"\n if points is not None:\n point_inputs = {\"point_coords\": points, \"point_labels\": labels}\n self.inference_state[\"point_inputs_per_obj\"][obj_idx][frame_idx] = point_inputs\n pop_key = \"mask_inputs_per_obj\"\n self.inference_state[\"mask_inputs_per_obj\"][obj_idx][frame_idx] = masks\n self.inference_state[pop_key][obj_idx].pop(frame_idx, None)\n # If this frame hasn't been tracked before, we treat it as an initial conditioning\n # frame, meaning that the inputs points are to generate segments on this frame without\n # using any memory from other frames, like in SAM. 
Otherwise (if it has been tracked),\n # the input points will be used to correct the already tracked masks.\n is_init_cond_frame = frame_idx not in self.inference_state[\"frames_already_tracked\"]\n obj_output_dict = self.inference_state[\"output_dict_per_obj\"][obj_idx]\n obj_temp_output_dict = self.inference_state[\"temp_output_dict_per_obj\"][obj_idx]\n # Add a frame to conditioning output if it's an initial conditioning frame or\n # if the model sees all frames receiving clicks/mask as conditioning frames.\n is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond\n storage_key = \"cond_frame_outputs\" if is_cond else \"non_cond_frame_outputs\"\n\n # Get any previously predicted mask logits on this object and feed it along with\n # the new clicks into the SAM mask decoder.\n prev_sam_mask_logits = None\n # lookup temporary output dict first, which contains the most recent output\n # (if not found, then lookup conditioning and non-conditioning frame output)\n if point_inputs is not None:\n prev_out = (\n obj_temp_output_dict[storage_key].get(frame_idx)\n or obj_output_dict[\"cond_frame_outputs\"].get(frame_idx)\n or obj_output_dict[\"non_cond_frame_outputs\"].get(frame_idx)\n )\n\n if prev_out is not None and prev_out.get(\"pred_masks\") is not None:\n prev_sam_mask_logits = prev_out[\"pred_masks\"].to(device=self.device, non_blocking=True)\n # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.\n prev_sam_mask_logits.clamp_(-32.0, 32.0)\n current_out = self._run_single_frame_inference(\n output_dict=obj_output_dict, # run on the slice of a single object\n frame_idx=frame_idx,\n batch_size=1, # run on the slice of a single object\n is_init_cond_frame=is_init_cond_frame,\n point_inputs=point_inputs,\n mask_inputs=masks,\n reverse=False,\n # Skip the memory encoder when adding clicks or mask. We execute the memory encoder\n # at the beginning of `propagate_in_video` (after user finalize their clicks). 
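The lookup order for previous mask logits can be sketched with plain dictionaries. The names `temp_out` and `obj_out` below are toy stand-ins, not the real inference-state structures; the sketch only shows the precedence: the freshest temporary (interaction) output is checked first, then the consolidated conditioning output, then the plain tracked output.

# Toy stand-ins for the per-object stores (illustrative names only).
temp_out = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
obj_out = {"cond_frame_outputs": {3: {"pred_masks": "logits"}}, "non_cond_frame_outputs": {}}

frame_idx, storage_key = 3, "cond_frame_outputs"
prev_out = (
    temp_out[storage_key].get(frame_idx)                 # freshest interaction result first
    or obj_out["cond_frame_outputs"].get(frame_idx)      # then consolidated conditioning output
    or obj_out["non_cond_frame_outputs"].get(frame_idx)  # finally the plain tracked output
)
print(prev_out)  # {'pred_masks': 'logits'}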
This\n # allows us to enforce non-overlapping constraints on all objects before encoding\n # them into memory.\n run_mem_encoder=False,\n prev_sam_mask_logits=prev_sam_mask_logits,\n )\n # Add the output to the output dict (to be used as future memory)\n obj_temp_output_dict[storage_key][frame_idx] = current_out\n\n # Resize the output mask to the original video resolution\n consolidated_out = self._consolidate_temp_output_across_obj(\n frame_idx,\n is_cond=is_cond,\n run_mem_encoder=False,\n )\n pred_masks = consolidated_out[\"pred_masks\"].flatten(0, 1)\n return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)\n\n @smart_inference_mode()\n def propagate_in_video_preflight(self):\n \"\"\"\n Prepare inference_state and consolidate temporary outputs before tracking.\n\n This method marks the start of tracking, disallowing the addition of new objects until the session is reset.\n It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.\n Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent\n with the provided inputs.\n \"\"\"\n # Tracking has started and we don't allow adding new objects until session is reset.\n self.inference_state[\"tracking_has_started\"] = True\n batch_size = len(self.inference_state[\"obj_idx_to_id\"])\n\n # Consolidate per-object temporary outputs in \"temp_output_dict_per_obj\" and\n # add them into \"output_dict\".\n temp_output_dict_per_obj = self.inference_state[\"temp_output_dict_per_obj\"]\n output_dict = self.inference_state[\"output_dict\"]\n # \"consolidated_frame_inds\" contains indices of those frames where consolidated\n # temporary outputs have been added (either in this call or any previous calls\n # to `propagate_in_video_preflight`).\n consolidated_frame_inds = self.inference_state[\"consolidated_frame_inds\"]\n for is_cond in {False, True}:\n # Separately consolidate conditioning and non-conditioning temp outputs\n storage_key = \"cond_frame_outputs\" if is_cond else \"non_cond_frame_outputs\"\n # Find all the frames that contain temporary outputs for any objects\n # (these should be the frames that have just received clicks for mask inputs\n # via `add_new_points` or `add_new_mask`)\n temp_frame_inds = set()\n for obj_temp_output_dict in temp_output_dict_per_obj.values():\n temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())\n consolidated_frame_inds[storage_key].update(temp_frame_inds)\n # consolidate the temporary output across all objects on this frame\n for frame_idx in temp_frame_inds:\n consolidated_out = self._consolidate_temp_output_across_obj(\n frame_idx, is_cond=is_cond, run_mem_encoder=True\n )\n # merge them into \"output_dict\" and also create per-object slices\n output_dict[storage_key][frame_idx] = consolidated_out\n self._add_output_per_object(frame_idx, consolidated_out, storage_key)\n if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):\n # clear non-conditioning memory of the surrounding frames\n self._clear_non_cond_mem_around_input(frame_idx)\n\n # clear temporary outputs in `temp_output_dict_per_obj`\n for obj_temp_output_dict in temp_output_dict_per_obj.values():\n obj_temp_output_dict[storage_key].clear()\n\n # edge case: if an output is added to \"cond_frame_outputs\", we remove any prior\n # output on the same frame in \"non_cond_frame_outputs\"\n for frame_idx in output_dict[\"cond_frame_outputs\"]:\n 
output_dict[\"non_cond_frame_outputs\"].pop(frame_idx, None)\n for obj_output_dict in self.inference_state[\"output_dict_per_obj\"].values():\n for frame_idx in obj_output_dict[\"cond_frame_outputs\"]:\n obj_output_dict[\"non_cond_frame_outputs\"].pop(frame_idx, None)\n for frame_idx in consolidated_frame_inds[\"cond_frame_outputs\"]:\n assert frame_idx in output_dict[\"cond_frame_outputs\"]\n consolidated_frame_inds[\"non_cond_frame_outputs\"].discard(frame_idx)\n\n # Make sure that the frame indices in \"consolidated_frame_inds\" are exactly those frames\n # with either points or mask inputs (which should be true under a correct workflow).\n all_consolidated_frame_inds = (\n consolidated_frame_inds[\"cond_frame_outputs\"] | consolidated_frame_inds[\"non_cond_frame_outputs\"]\n )\n input_frames_inds = set()\n for point_inputs_per_frame in self.inference_state[\"point_inputs_per_obj\"].values():\n input_frames_inds.update(point_inputs_per_frame.keys())\n for mask_inputs_per_frame in self.inference_state[\"mask_inputs_per_obj\"].values():\n input_frames_inds.update(mask_inputs_per_frame.keys())\n assert all_consolidated_frame_inds == input_frames_inds\n\n @staticmethod\n def init_state(predictor):\n \"\"\"\n Initialize an inference state for the predictor.\n\n This function sets up the initial state required for performing inference on video data.\n It includes initializing various dictionaries and ordered dictionaries that will store\n inputs, outputs, and other metadata relevant to the tracking process.\n\n Args:\n predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.\n \"\"\"\n if len(predictor.inference_state) > 0: # means initialized\n return\n assert predictor.dataset is not None\n assert predictor.dataset.mode == \"video\"\n\n inference_state = {\n \"num_frames\": predictor.dataset.frames,\n \"point_inputs_per_obj\": {}, # inputs points on each frame\n \"mask_inputs_per_obj\": {}, # inputs mask on each frame\n \"constants\": {}, # values that don't change across frames (so we only need to hold one copy of them)\n # mapping between client-side object id and model-side object index\n \"obj_id_to_idx\": OrderedDict(),\n \"obj_idx_to_id\": OrderedDict(),\n \"obj_ids\": [],\n # A storage to hold the model's tracking results and states on each frame\n \"output_dict\": {\n \"cond_frame_outputs\": {}, # dict containing {frame_idx: }\n \"non_cond_frame_outputs\": {}, # dict containing {frame_idx: }\n },\n # Slice (view) of each object tracking results, sharing the same memory with \"output_dict\"\n \"output_dict_per_obj\": {},\n # A temporary storage to hold new outputs when user interact with a frame\n # to add clicks or mask (it's merged into \"output_dict\" before propagation starts)\n \"temp_output_dict_per_obj\": {},\n # Frames that already holds consolidated outputs from click or mask inputs\n # (we directly use their consolidated outputs during tracking)\n \"consolidated_frame_inds\": {\n \"cond_frame_outputs\": set(), # set containing frame indices\n \"non_cond_frame_outputs\": set(), # set containing frame indices\n },\n # metadata for each tracking frame (e.g. 
which direction it's tracked)\n \"tracking_has_started\": False,\n \"frames_already_tracked\": [],\n }\n predictor.inference_state = inference_state\n\n def get_im_features(self, im, batch=1):\n \"\"\"\n Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.\n\n Args:\n im (torch.Tensor): The input image tensor.\n batch (int, optional): The batch size for expanding features if there are multiple prompts.\n\n Returns:\n vis_feats (torch.Tensor): The visual features extracted from the image.\n vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.\n feat_sizes (List[tuple]): A list containing the sizes of the extracted features.\n\n Note:\n - If `batch` is greater than 1, the features are expanded to fit the batch size.\n - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.\n \"\"\"\n backbone_out = self.model.forward_image(im)\n if batch > 1: # expand features if there's more than one prompt\n for i, feat in enumerate(backbone_out[\"backbone_fpn\"]):\n backbone_out[\"backbone_fpn\"][i] = feat.expand(batch, -1, -1, -1)\n for i, pos in enumerate(backbone_out[\"vision_pos_enc\"]):\n pos = pos.expand(batch, -1, -1, -1)\n backbone_out[\"vision_pos_enc\"][i] = pos\n _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)\n return vis_feats, vis_pos_embed, feat_sizes\n\n def _obj_id_to_idx(self, obj_id):\n \"\"\"\n Map client-side object id to model-side object index.\n\n Args:\n obj_id (int): The unique identifier of the object provided by the client side.\n\n Returns:\n (int): The index of the object on the model side.\n\n Raises:\n RuntimeError: If an attempt is made to add a new object after tracking has started.\n\n Note:\n - The method updates or retrieves mappings between object IDs and indices stored in\n `inference_state`.\n - It ensures that new objects can only be added before tracking commences.\n - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).\n - Additional data structures are initialized for the new object to store inputs and outputs.\n \"\"\"\n obj_idx = self.inference_state[\"obj_id_to_idx\"].get(obj_id, None)\n if obj_idx is not None:\n return obj_idx\n\n # This is a new object id not sent to the server before. We only allow adding\n # new objects *before* the tracking starts.\n allow_new_object = not self.inference_state[\"tracking_has_started\"]\n if allow_new_object:\n # get the next object slot\n obj_idx = len(self.inference_state[\"obj_id_to_idx\"])\n self.inference_state[\"obj_id_to_idx\"][obj_id] = obj_idx\n self.inference_state[\"obj_idx_to_id\"][obj_idx] = obj_id\n self.inference_state[\"obj_ids\"] = list(self.inference_state[\"obj_id_to_idx\"])\n # set up input and output structures for this object\n self.inference_state[\"point_inputs_per_obj\"][obj_idx] = {}\n self.inference_state[\"mask_inputs_per_obj\"][obj_idx] = {}\n self.inference_state[\"output_dict_per_obj\"][obj_idx] = {\n \"cond_frame_outputs\": {}, # dict containing {frame_idx: }\n \"non_cond_frame_outputs\": {}, # dict containing {frame_idx: }\n }\n self.inference_state[\"temp_output_dict_per_obj\"][obj_idx] = {\n \"cond_frame_outputs\": {}, # dict containing {frame_idx: }\n \"non_cond_frame_outputs\": {}, # dict containing {frame_idx: }\n }\n return obj_idx\n else:\n raise RuntimeError(\n f\"Cannot add new object id {obj_id} after tracking starts. 
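The two-way id/index bookkeeping performed by `_obj_id_to_idx` can be illustrated with a small standalone sketch; `to_model_index` is a hypothetical helper, not part of the predictor, and the surrounding "tracking has started" guard is omitted for brevity.

from collections import OrderedDict

obj_id_to_idx, obj_idx_to_id = OrderedDict(), OrderedDict()

def to_model_index(obj_id: int) -> int:
    # Assign the next free model-side slot to a new client-side id and keep both
    # mappings in sync; an existing id simply returns its stored index.
    if obj_id not in obj_id_to_idx:
        idx = len(obj_id_to_idx)
        obj_id_to_idx[obj_id] = idx
        obj_idx_to_id[idx] = obj_id
    return obj_id_to_idx[obj_id]

assert to_model_index(42) == 0 and to_model_index(7) == 1 and to_model_index(42) == 0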
\"\n f\"All existing object ids: {self.inference_state['obj_ids']}. \"\n f\"Please call 'reset_state' to restart from scratch.\"\n )\n\n def _run_single_frame_inference(\n self,\n output_dict,\n frame_idx,\n batch_size,\n is_init_cond_frame,\n point_inputs,\n mask_inputs,\n reverse,\n run_mem_encoder,\n prev_sam_mask_logits=None,\n ):\n \"\"\"\n Run tracking on a single frame based on current inputs and previous memory.\n\n Args:\n output_dict (dict): The dictionary containing the output states of the tracking process.\n frame_idx (int): The index of the current frame.\n batch_size (int): The batch size for processing the frame.\n is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.\n point_inputs (dict | None): Input points and their labels.\n mask_inputs (torch.Tensor | None): Input binary masks.\n reverse (bool): Indicates if the tracking should be performed in reverse order.\n run_mem_encoder (bool): Indicates if the memory encoder should be executed.\n prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.\n\n Returns:\n (dict): A dictionary containing the output of the tracking step, including updated features and predictions.\n\n Raises:\n AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.\n\n Note:\n - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.\n - The method retrieves image features using the `get_im_features` method.\n - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.\n - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.\n \"\"\"\n # Retrieve correct image features\n current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(\n self.inference_state[\"im\"], batch_size\n )\n\n # point and mask should not appear as input simultaneously on the same frame\n assert point_inputs is None or mask_inputs is None\n current_out = self.model.track_step(\n frame_idx=frame_idx,\n is_init_cond_frame=is_init_cond_frame,\n current_vision_feats=current_vision_feats,\n current_vision_pos_embeds=current_vision_pos_embeds,\n feat_sizes=feat_sizes,\n point_inputs=point_inputs,\n mask_inputs=mask_inputs,\n output_dict=output_dict,\n num_frames=self.inference_state[\"num_frames\"],\n track_in_reverse=reverse,\n run_mem_encoder=run_mem_encoder,\n prev_sam_mask_logits=prev_sam_mask_logits,\n )\n\n maskmem_features = current_out[\"maskmem_features\"]\n if maskmem_features is not None:\n current_out[\"maskmem_features\"] = maskmem_features.to(\n dtype=torch.float16, device=self.device, non_blocking=True\n )\n # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions\n # potentially fill holes in the predicted masks\n # if self.fill_hole_area > 0:\n # pred_masks = current_out[\"pred_masks\"].to(self.device, non_blocking=True)\n # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)\n\n # \"maskmem_pos_enc\" is the same across frames, so we only need to store one copy of it\n current_out[\"maskmem_pos_enc\"] = self._get_maskmem_pos_enc(current_out[\"maskmem_pos_enc\"])\n return current_out\n\n def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):\n \"\"\"\n Cache and manage the positional encoding for mask memory across frames and objects.\n\n This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for\n mask memory, which is constant 
across frames and objects, thus reducing the amount of\n redundant information stored during an inference session. It checks if the positional\n encoding has already been cached; if not, it caches a slice of the provided encoding.\n If the batch size is greater than one, it expands the cached positional encoding to match\n the current batch size.\n\n Args:\n out_maskmem_pos_enc (List[torch.Tensor] | None): The positional encoding for mask memory.\n Should be a list of tensors or None.\n\n Returns:\n (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.\n\n Note:\n - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.\n - Only a single object's slice is cached since the encoding is the same across objects.\n - The method checks if the positional encoding has already been cached in the session's constants.\n - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.\n \"\"\"\n model_constants = self.inference_state[\"constants\"]\n # \"out_maskmem_pos_enc\" should be either a list of tensors or None\n if out_maskmem_pos_enc is not None:\n if \"maskmem_pos_enc\" not in model_constants:\n assert isinstance(out_maskmem_pos_enc, list)\n # only take the slice for one object, since it's same across objects\n maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]\n model_constants[\"maskmem_pos_enc\"] = maskmem_pos_enc\n else:\n maskmem_pos_enc = model_constants[\"maskmem_pos_enc\"]\n # expand the cached maskmem_pos_enc to the actual batch size\n batch_size = out_maskmem_pos_enc[0].size(0)\n if batch_size > 1:\n out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]\n return out_maskmem_pos_enc\n\n def _consolidate_temp_output_across_obj(\n self,\n frame_idx,\n is_cond=False,\n run_mem_encoder=False,\n ):\n \"\"\"\n Consolidate per-object temporary outputs into a single output for all objects.\n\n This method combines the temporary outputs for each object on a given frame into a unified\n output. It fills in any missing objects either from the main output dictionary or leaves\n placeholders if they do not exist in the main output. Optionally, it can re-run the memory\n encoder after applying non-overlapping constraints to the object scores.\n\n Args:\n frame_idx (int): The index of the frame for which to consolidate outputs.\n is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.\n run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after\n consolidating the outputs.\n\n Returns:\n (dict): A consolidated output dictionary containing the combined results for all objects.\n\n Note:\n - The method initializes the consolidated output with placeholder values for missing objects.\n - It searches for outputs in both the temporary and main output dictionaries.\n - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.\n - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.\n \"\"\"\n batch_size = len(self.inference_state[\"obj_idx_to_id\"])\n storage_key = \"cond_frame_outputs\" if is_cond else \"non_cond_frame_outputs\"\n\n # Initialize `consolidated_out`. Its \"maskmem_features\" and \"maskmem_pos_enc\"\n # will be added when rerunning the memory encoder after applying non-overlapping\n # constraints to object scores. 
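A toy version of the caching strategy used by `_get_maskmem_pos_enc`: store one object's slice once, then broadcast it to whatever batch size is requested later. `constants` and `cache_and_expand` are illustrative names, and the sketch assumes 4-D (batch, C, H, W) encodings.

import torch

constants = {}  # per-session cache, analogous in spirit to inference_state["constants"]

def cache_and_expand(pos_enc_list):
    # Cache a single-object slice once, then broadcast it to the current batch size.
    if "maskmem_pos_enc" not in constants:
        constants["maskmem_pos_enc"] = [x[:1].clone() for x in pos_enc_list]
    cached = constants["maskmem_pos_enc"]
    batch = pos_enc_list[0].size(0)
    return [x.expand(batch, -1, -1, -1) for x in cached] if batch > 1 else cached

enc = [torch.randn(3, 64, 16, 16)]           # (batch, C, H, W) positional encoding
print(cache_and_expand(enc)[0].shape)        # torch.Size([3, 64, 16, 16])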
Its \"pred_masks\" are prefilled with a large\n # negative value (NO_OBJ_SCORE) to represent missing objects.\n consolidated_out = {\n \"maskmem_features\": None,\n \"maskmem_pos_enc\": None,\n \"pred_masks\": torch.full(\n size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),\n fill_value=-1024.0,\n dtype=torch.float32,\n device=self.device,\n ),\n \"obj_ptr\": torch.full(\n size=(batch_size, self.model.hidden_dim),\n fill_value=-1024.0,\n dtype=torch.float32,\n device=self.device,\n ),\n \"object_score_logits\": torch.full(\n size=(batch_size, 1),\n # default to 10.0 for object_score_logits, i.e. assuming the object is\n # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`\n fill_value=10.0,\n dtype=torch.float32,\n device=self.device,\n ),\n }\n for obj_idx in range(batch_size):\n obj_temp_output_dict = self.inference_state[\"temp_output_dict_per_obj\"][obj_idx]\n obj_output_dict = self.inference_state[\"output_dict_per_obj\"][obj_idx]\n out = (\n obj_temp_output_dict[storage_key].get(frame_idx)\n # If the object doesn't appear in \"temp_output_dict_per_obj\" on this frame,\n # we fall back and look up its previous output in \"output_dict_per_obj\".\n # We look up both \"cond_frame_outputs\" and \"non_cond_frame_outputs\" in\n # \"output_dict_per_obj\" to find a previous output for this object.\n or obj_output_dict[\"cond_frame_outputs\"].get(frame_idx)\n or obj_output_dict[\"non_cond_frame_outputs\"].get(frame_idx)\n )\n # If the object doesn't appear in \"output_dict_per_obj\" either, we skip it\n # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE\n # placeholder above) and set its object pointer to be a dummy pointer.\n if out is None:\n # Fill in dummy object pointers for those objects without any inputs or\n # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,\n # i.e. 
when we need to build the memory for tracking).\n if run_mem_encoder:\n # fill object pointer with a dummy pointer (based on an empty mask)\n consolidated_out[\"obj_ptr\"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)\n continue\n # Add the temporary object output mask to consolidated output mask\n consolidated_out[\"pred_masks\"][obj_idx : obj_idx + 1] = out[\"pred_masks\"]\n consolidated_out[\"obj_ptr\"][obj_idx : obj_idx + 1] = out[\"obj_ptr\"]\n\n # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder\n if run_mem_encoder:\n high_res_masks = F.interpolate(\n consolidated_out[\"pred_masks\"],\n size=self.imgsz,\n mode=\"bilinear\",\n align_corners=False,\n )\n if self.model.non_overlap_masks_for_mem_enc:\n high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)\n consolidated_out[\"maskmem_features\"], consolidated_out[\"maskmem_pos_enc\"] = self._run_memory_encoder(\n batch_size=batch_size,\n high_res_masks=high_res_masks,\n is_mask_from_pts=True, # these frames are what the user interacted with\n object_score_logits=consolidated_out[\"object_score_logits\"],\n )\n\n return consolidated_out\n\n def _get_empty_mask_ptr(self, frame_idx):\n \"\"\"\n Get a dummy object pointer based on an empty mask on the current frame.\n\n Args:\n frame_idx (int): The index of the current frame for which to generate the dummy object pointer.\n\n Returns:\n (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.\n \"\"\"\n # Retrieve correct image features\n current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state[\"im\"])\n\n # Feed the empty mask and image feature above to get a dummy object pointer\n current_out = self.model.track_step(\n frame_idx=frame_idx,\n is_init_cond_frame=True,\n current_vision_feats=current_vision_feats,\n current_vision_pos_embeds=current_vision_pos_embeds,\n feat_sizes=feat_sizes,\n point_inputs=None,\n # A dummy (empty) mask with a single object\n mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),\n output_dict={},\n num_frames=self.inference_state[\"num_frames\"],\n track_in_reverse=False,\n run_mem_encoder=False,\n prev_sam_mask_logits=None,\n )\n return current_out[\"obj_ptr\"]\n\n def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):\n \"\"\"\n Run the memory encoder on masks.\n\n This is usually after applying non-overlapping constraints to object scores. 
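The resize step above is ordinary bilinear upsampling of low-resolution mask logits; a minimal example, with a 1024x1024 target size chosen purely for illustration:

import torch
import torch.nn.functional as F

low_res = torch.randn(2, 1, 256, 256)    # (num_objects, 1, H/4, W/4) consolidated mask logits
high_res = F.interpolate(low_res, size=(1024, 1024), mode="bilinear", align_corners=False)
print(high_res.shape)                     # torch.Size([2, 1, 1024, 1024])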
Since their scores changed, their\n memory also needs to be computed again with the memory encoder.\n\n Args:\n batch_size (int): The batch size for processing the frame.\n high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.\n object_score_logits (torch.Tensor): Logits representing the object scores.\n is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.\n\n Returns:\n maskmem_features (torch.Tensor): The encoded mask features.\n maskmem_pos_enc (torch.Tensor): The positional encoding.\n \"\"\"\n # Retrieve correct image features\n current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state[\"im\"], batch_size)\n maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(\n current_vision_feats=current_vision_feats,\n feat_sizes=feat_sizes,\n pred_masks_high_res=high_res_masks,\n is_mask_from_pts=is_mask_from_pts,\n object_score_logits=object_score_logits,\n )\n\n # \"maskmem_pos_enc\" is the same across frames, so we only need to store one copy of it\n maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)\n return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc\n\n def _add_output_per_object(self, frame_idx, current_out, storage_key):\n \"\"\"\n Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.\n\n The resulting slices share the same tensor storage.\n\n Args:\n frame_idx (int): The index of the current frame.\n current_out (dict): The current output dictionary containing multi-object outputs.\n storage_key (str): The key used to store the output in the per-object output dictionary.\n \"\"\"\n maskmem_features = current_out[\"maskmem_features\"]\n assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)\n\n maskmem_pos_enc = current_out[\"maskmem_pos_enc\"]\n assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)\n\n for obj_idx, obj_output_dict in self.inference_state[\"output_dict_per_obj\"].items():\n obj_slice = slice(obj_idx, obj_idx + 1)\n obj_out = {\n \"maskmem_features\": None,\n \"maskmem_pos_enc\": None,\n \"pred_masks\": current_out[\"pred_masks\"][obj_slice],\n \"obj_ptr\": current_out[\"obj_ptr\"][obj_slice],\n }\n if maskmem_features is not None:\n obj_out[\"maskmem_features\"] = maskmem_features[obj_slice]\n if maskmem_pos_enc is not None:\n obj_out[\"maskmem_pos_enc\"] = [x[obj_slice] for x in maskmem_pos_enc]\n obj_output_dict[storage_key][frame_idx] = obj_out\n\n def _clear_non_cond_mem_around_input(self, frame_idx):\n \"\"\"\n Remove the non-conditioning memory around the input frame.\n\n When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated\n object appearance information and could confuse the model. 
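The clearing window itself is simple bookkeeping; below is a toy sketch, with made-up stride and memory-size values, of dropping cached non-conditioning outputs within `r * num_maskmem` frames of the interacted frame.

# Toy frame store: frame index -> cached non-conditioning output.
non_cond_frame_outputs = {t: f"mem_{t}" for t in range(30)}

frame_idx, stride, num_maskmem = 15, 1, 7      # illustrative values only
for t in range(frame_idx - stride * num_maskmem, frame_idx + stride * num_maskmem + 1):
    non_cond_frame_outputs.pop(t, None)        # drop memories in the window around the click
print(sorted(non_cond_frame_outputs))          # frames 0-7 and 23-29 remain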
This method clears those non-conditioning memories\n surrounding the interacted frame to avoid giving the model both old and new information about the object.\n\n Args:\n frame_idx (int): The index of the current frame where user interaction occurred.\n \"\"\"\n r = self.model.memory_temporal_stride_for_eval\n frame_idx_begin = frame_idx - r * self.model.num_maskmem\n frame_idx_end = frame_idx + r * self.model.num_maskmem\n for t in range(frame_idx_begin, frame_idx_end + 1):\n self.inference_state[\"output_dict\"][\"non_cond_frame_outputs\"].pop(t, None)\n for obj_output_dict in self.inference_state[\"output_dict_per_obj\"].values():\n obj_output_dict[\"non_cond_frame_outputs\"].pop(t, None)", "chunk_type": "class", "name": "SAM2VideoPredictor", "file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py", "start_line": 817, "end_line": 1618, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": "SAM2VideoPredictor to handle user interactions with videos and manage inference states.\n\nThis class extends the functionality of SAM2Predictor to support video processing and maintains\nthe state of inference operations. It includes configurations for managing non-overlapping masks,\nclearing memory for non-conditional inputs, and setting up callbacks for prediction events.\n\nAttributes:\n inference_state (dict): A dictionary to store the current state of inference operations.\n non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.\n clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.\n clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.\n callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.\n\nMethods:\n get_model: Retrieve and configure the model with binarization enabled.\n inference: Perform image segmentation inference based on the given input cues.\n postprocess: Post-process the predictions to apply non-overlapping constraints if required.\n add_new_prompts: Add new points or masks to a specific frame for a given object ID.\n propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.\n init_state: Initialize an inference state for the predictor.\n get_im_features: Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.\n\nExamples:\n >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)\n >>> predictor.set_image(\"path/to/video_frame.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)\n\nNote:\n The `fill_hole_area` attribute is defined but not used in the current implementation.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "collections.OrderedDict", "numpy", "torch", "torch.nn.functional", "ultralytics.data.augment.LetterBox", "ultralytics.engine.predictor.BasePredictor", "ultralytics.engine.results.Results", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.ops", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode", "amg.batch_iterator", "amg.batched_mask_to_box", "amg.build_all_layer_point_grids", "amg.calculate_stability_score", "amg.generate_crop_boxes", "amg.is_box_near_crop_edge", "amg.remove_small_regions", "amg.uncrop_boxes_xyxy", "amg.uncrop_masks", "torchvision", "build.build_sam", "torchvision", "build.build_sam", "SAM2Predictor" ], "chunk_id": 
"class_SAM2VideoPredictor_4656fcbd" }, { "content": "from .model import SAM", "chunk_type": "import", "name": "SAM", "file_path": "ultralytics\\ultralytics\\models\\sam\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SAM_fac33890" }, { "content": "from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor", "chunk_type": "import", "name": "Predictor, SAM2Predictor, SAM2VideoPredictor", "file_path": "ultralytics\\ultralytics\\models\\sam\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Predictor, SAM2Predictor, SAM2VideoPredictor_61c8aadd" }, { "content": "__all__ = \"SAM\", \"Predictor\", \"SAM2Predictor\", \"SAM2VideoPredictor\" # tuple or list of exportable items", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\sam\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___9af0d37e" }, { "content": "from typing import Any, Dict, List, Optional, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple_2395c2be" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_49f8c258" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_596b42cc" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_dc48fad9" }, { "content": "from ultralytics.utils.loss import FocalLoss, VarifocalLoss", "chunk_type": "import", "name": "FocalLoss, VarifocalLoss", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": 
null, "chunk_id": "import_FocalLoss, VarifocalLoss_a255af35" }, { "content": "from ultralytics.utils.metrics import bbox_iou", "chunk_type": "import", "name": "bbox_iou", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bbox_iou_3e0995ff" }, { "content": "from .ops import HungarianMatcher", "chunk_type": "import", "name": "HungarianMatcher", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HungarianMatcher_a687dd40" }, { "content": "class DETRLoss(nn.Module):\n \"\"\"\n DETR (DEtection TRansformer) Loss class for calculating various loss components.\n\n This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the\n DETR object detection model.\n\n Attributes:\n nc (int): Number of classes.\n loss_gain (Dict[str, float]): Coefficients for different loss components.\n aux_loss (bool): Whether to compute auxiliary losses.\n use_fl (bool): Whether to use FocalLoss.\n use_vfl (bool): Whether to use VarifocalLoss.\n use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.\n uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.\n matcher (HungarianMatcher): Object to compute matching cost and indices.\n fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.\n vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.\n device (torch.device): Device on which tensors are stored.\n \"\"\"\n\n def __init__(\n self,\n nc: int = 80,\n loss_gain: Optional[Dict[str, float]] = None,\n aux_loss: bool = True,\n use_fl: bool = True,\n use_vfl: bool = False,\n use_uni_match: bool = False,\n uni_match_ind: int = 0,\n gamma: float = 1.5,\n alpha: float = 0.25,\n ):\n \"\"\"\n Initialize DETR loss function with customizable components and gains.\n\n Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. 
Supports auxiliary\n losses and various loss types.\n\n Args:\n nc (int): Number of classes.\n loss_gain (Dict[str, float], optional): Coefficients for different loss components.\n aux_loss (bool): Whether to use auxiliary losses from each decoder layer.\n use_fl (bool): Whether to use FocalLoss.\n use_vfl (bool): Whether to use VarifocalLoss.\n use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.\n uni_match_ind (int): Index of fixed layer for uni_match.\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (float): The balancing factor used to address class imbalance.\n \"\"\"\n super().__init__()\n\n if loss_gain is None:\n loss_gain = {\"class\": 1, \"bbox\": 5, \"giou\": 2, \"no_object\": 0.1, \"mask\": 1, \"dice\": 1}\n self.nc = nc\n self.matcher = HungarianMatcher(cost_gain={\"class\": 2, \"bbox\": 5, \"giou\": 2})\n self.loss_gain = loss_gain\n self.aux_loss = aux_loss\n self.fl = FocalLoss(gamma, alpha) if use_fl else None\n self.vfl = VarifocalLoss(gamma, alpha) if use_vfl else None\n\n self.use_uni_match = use_uni_match\n self.uni_match_ind = uni_match_ind\n self.device = None\n\n def _get_loss_class(\n self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = \"\"\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Compute classification loss based on predictions, target values, and ground truth scores.\n\n Args:\n pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).\n targets (torch.Tensor): Target class indices with shape (B, N).\n gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).\n num_gts (int): Number of ground truth objects.\n postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary containing classification loss value.\n\n Notes:\n The function supports different classification loss types:\n - Varifocal Loss (if self.vfl is True and num_gts > 0)\n - Focal Loss (if self.fl is True)\n - BCE Loss (default fallback)\n \"\"\"\n # Logits: [b, query, num_classes], gt_class: list[[n, 1]]\n name_class = f\"loss_class{postfix}\"\n bs, nq = pred_scores.shape[:2]\n # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)\n one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)\n one_hot.scatter_(2, targets.unsqueeze(-1), 1)\n one_hot = one_hot[..., :-1]\n gt_scores = gt_scores.view(bs, nq, 1) * one_hot\n\n if self.fl:\n if num_gts and self.vfl:\n loss_cls = self.vfl(pred_scores, gt_scores, one_hot)\n else:\n loss_cls = self.fl(pred_scores, one_hot.float())\n loss_cls /= max(num_gts, 1) / nq\n else:\n loss_cls = nn.BCEWithLogitsLoss(reduction=\"none\")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss\n\n return {name_class: loss_cls.squeeze() * self.loss_gain[\"class\"]}\n\n def _get_loss_bbox(\n self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = \"\"\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).\n postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary containing:\n - 
loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.\n - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.\n\n Notes:\n If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.\n \"\"\"\n # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]\n name_bbox = f\"loss_bbox{postfix}\"\n name_giou = f\"loss_giou{postfix}\"\n\n loss = {}\n if len(gt_bboxes) == 0:\n loss[name_bbox] = torch.tensor(0.0, device=self.device)\n loss[name_giou] = torch.tensor(0.0, device=self.device)\n return loss\n\n loss[name_bbox] = self.loss_gain[\"bbox\"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction=\"sum\") / len(gt_bboxes)\n loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)\n loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)\n loss[name_giou] = self.loss_gain[\"giou\"] * loss[name_giou]\n return {k: v.squeeze() for k, v in loss.items()}\n\n # This function is for future RT-DETR Segment models\n # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):\n # # masks: [b, query, h, w], gt_mask: list[[n, H, W]]\n # name_mask = f'loss_mask{postfix}'\n # name_dice = f'loss_dice{postfix}'\n #\n # loss = {}\n # if sum(len(a) for a in gt_mask) == 0:\n # loss[name_mask] = torch.tensor(0., device=self.device)\n # loss[name_dice] = torch.tensor(0., device=self.device)\n # return loss\n #\n # num_gts = len(gt_mask)\n # src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)\n # src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]\n # # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.\n # loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,\n # torch.tensor([num_gts], dtype=torch.float32))\n # loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)\n # return loss\n\n # This function is for future RT-DETR Segment models\n # @staticmethod\n # def _dice_loss(inputs, targets, num_gts):\n # inputs = F.sigmoid(inputs).flatten(1)\n # targets = targets.flatten(1)\n # numerator = 2 * (inputs * targets).sum(1)\n # denominator = inputs.sum(-1) + targets.sum(-1)\n # loss = 1 - (numerator + 1) / (denominator + 1)\n # return loss.sum() / num_gts\n\n def _get_loss_aux(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n gt_bboxes: torch.Tensor,\n gt_cls: torch.Tensor,\n gt_groups: List[int],\n match_indices: Optional[List[Tuple]] = None,\n postfix: str = \"\",\n masks: Optional[torch.Tensor] = None,\n gt_mask: Optional[torch.Tensor] = None,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Get auxiliary losses for intermediate decoder layers.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.\n pred_scores (torch.Tensor): Predicted scores from auxiliary layers.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes.\n gt_cls (torch.Tensor): Ground truth classes.\n gt_groups (List[int]): Number of ground truths per image.\n match_indices (List[Tuple], optional): Pre-computed matching indices.\n postfix (str, optional): String to append to loss names.\n masks (torch.Tensor, optional): Predicted masks if using segmentation.\n gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary of auxiliary losses.\n \"\"\"\n # NOTE: loss class, bbox, giou, 
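The GIoU term can be reproduced in isolation with the same `bbox_iou` call used here. This snippet assumes the `ultralytics` package is installed; the boxes are toy values, and the factor of 2 matches the default `loss_gain["giou"]`.

import torch
from ultralytics.utils.metrics import bbox_iou

# Two already-matched prediction/ground-truth pairs in normalized xywh format.
pred = torch.tensor([[0.50, 0.50, 0.20, 0.20], [0.30, 0.30, 0.10, 0.10]])
gt = torch.tensor([[0.50, 0.50, 0.25, 0.25], [0.80, 0.80, 0.10, 0.10]])
giou = bbox_iou(pred, gt, xywh=True, GIoU=True)       # pairwise GIoU, shape (2, 1)
loss_giou = 2 * (1.0 - giou).sum() / len(gt)          # sum, average over pairs, apply gain
print(float(loss_giou))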
mask, dice\n loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)\n if match_indices is None and self.use_uni_match:\n match_indices = self.matcher(\n pred_bboxes[self.uni_match_ind],\n pred_scores[self.uni_match_ind],\n gt_bboxes,\n gt_cls,\n gt_groups,\n masks=masks[self.uni_match_ind] if masks is not None else None,\n gt_mask=gt_mask,\n )\n for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):\n aux_masks = masks[i] if masks is not None else None\n loss_ = self._get_loss(\n aux_bboxes,\n aux_scores,\n gt_bboxes,\n gt_cls,\n gt_groups,\n masks=aux_masks,\n gt_mask=gt_mask,\n postfix=postfix,\n match_indices=match_indices,\n )\n loss[0] += loss_[f\"loss_class{postfix}\"]\n loss[1] += loss_[f\"loss_bbox{postfix}\"]\n loss[2] += loss_[f\"loss_giou{postfix}\"]\n # if masks is not None and gt_mask is not None:\n # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)\n # loss[3] += loss_[f'loss_mask{postfix}']\n # loss[4] += loss_[f'loss_dice{postfix}']\n\n loss = {\n f\"loss_class_aux{postfix}\": loss[0],\n f\"loss_bbox_aux{postfix}\": loss[1],\n f\"loss_giou_aux{postfix}\": loss[2],\n }\n # if masks is not None and gt_mask is not None:\n # loss[f'loss_mask_aux{postfix}'] = loss[3]\n # loss[f'loss_dice_aux{postfix}'] = loss[4]\n return loss\n\n @staticmethod\n def _get_index(match_indices: List[Tuple]) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:\n \"\"\"\n Extract batch indices, source indices, and destination indices from match indices.\n\n Args:\n match_indices (List[Tuple]): List of tuples containing matched indices.\n\n Returns:\n batch_idx (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).\n dst_idx (torch.Tensor): Destination indices.\n \"\"\"\n batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])\n src_idx = torch.cat([src for (src, _) in match_indices])\n dst_idx = torch.cat([dst for (_, dst) in match_indices])\n return (batch_idx, src_idx), dst_idx\n\n def _get_assigned_bboxes(\n self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: List[Tuple]\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Assign predicted bounding boxes to ground truth bounding boxes based on match indices.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes.\n match_indices (List[Tuple]): List of tuples containing matched indices.\n\n Returns:\n pred_assigned (torch.Tensor): Assigned predicted bounding boxes.\n gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.\n \"\"\"\n pred_assigned = torch.cat(\n [\n t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)\n for t, (i, _) in zip(pred_bboxes, match_indices)\n ]\n )\n gt_assigned = torch.cat(\n [\n t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)\n for t, (_, j) in zip(gt_bboxes, match_indices)\n ]\n )\n return pred_assigned, gt_assigned\n\n def _get_loss(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n gt_bboxes: torch.Tensor,\n gt_cls: torch.Tensor,\n gt_groups: List[int],\n masks: Optional[torch.Tensor] = None,\n gt_mask: Optional[torch.Tensor] = None,\n postfix: str = \"\",\n match_indices: Optional[List[Tuple]] = None,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Calculate losses for a single prediction layer.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes.\n pred_scores (torch.Tensor): Predicted class scores.\n gt_bboxes 
(torch.Tensor): Ground truth bounding boxes.\n gt_cls (torch.Tensor): Ground truth classes.\n gt_groups (List[int]): Number of ground truths per image.\n masks (torch.Tensor, optional): Predicted masks if using segmentation.\n gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.\n postfix (str, optional): String to append to loss names.\n match_indices (List[Tuple], optional): Pre-computed matching indices.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary of losses.\n \"\"\"\n if match_indices is None:\n match_indices = self.matcher(\n pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask\n )\n\n idx, gt_idx = self._get_index(match_indices)\n pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]\n\n bs, nq = pred_scores.shape[:2]\n targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)\n targets[idx] = gt_cls[gt_idx]\n\n gt_scores = torch.zeros([bs, nq], device=pred_scores.device)\n if len(gt_bboxes):\n gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)\n\n return {\n **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),\n **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),\n # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})\n }\n\n def forward(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n batch: Dict[str, Any],\n postfix: str = \"\",\n **kwargs: Any,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Calculate loss for predicted bounding boxes and scores.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).\n pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).\n batch (Dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.\n postfix (str, optional): Postfix for loss names.\n **kwargs (Any): Additional arguments, may include 'match_indices'.\n\n Returns:\n (Dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).\n\n Notes:\n Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if\n self.aux_loss is True.\n \"\"\"\n self.device = pred_bboxes.device\n match_indices = kwargs.get(\"match_indices\", None)\n gt_cls, gt_bboxes, gt_groups = batch[\"cls\"], batch[\"bboxes\"], batch[\"gt_groups\"]\n\n total_loss = self._get_loss(\n pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices\n )\n\n if self.aux_loss:\n total_loss.update(\n self._get_loss_aux(\n pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix\n )\n )\n\n return total_loss", "chunk_type": "class", "name": "DETRLoss", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 15, "end_line": 397, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": "DETR (DEtection TRansformer) Loss class for calculating various loss components.\n\nThis class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the\nDETR object detection model.\n\nAttributes:\n nc (int): Number of classes.\n loss_gain (Dict[str, float]): Coefficients for different loss components.\n aux_loss (bool): Whether to compute auxiliary losses.\n use_fl (bool): Whether to use FocalLoss.\n use_vfl (bool): Whether to use VarifocalLoss.\n use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.\n 
uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.\n matcher (HungarianMatcher): Object to compute matching cost and indices.\n fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.\n vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.\n device (torch.device): Device on which tensors are stored.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.loss.FocalLoss", "ultralytics.utils.loss.VarifocalLoss", "ultralytics.utils.metrics.bbox_iou", "ops.HungarianMatcher", "nn.Module" ], "chunk_id": "class_DETRLoss_5541987f" }, { "content": "class RTDETRDetectionLoss(DETRLoss):\n \"\"\"\n Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.\n\n This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as\n an additional denoising training loss when provided with denoising metadata.\n \"\"\"\n\n def forward(\n self,\n preds: Tuple[torch.Tensor, torch.Tensor],\n batch: Dict[str, Any],\n dn_bboxes: Optional[torch.Tensor] = None,\n dn_scores: Optional[torch.Tensor] = None,\n dn_meta: Optional[Dict[str, Any]] = None,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Forward pass to compute detection loss with optional denoising loss.\n\n Args:\n preds (Tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.\n batch (Dict[str, Any]): Batch data containing ground truth information.\n dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.\n dn_scores (torch.Tensor, optional): Denoising scores.\n dn_meta (Dict[str, Any], optional): Metadata for denoising.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.\n \"\"\"\n pred_bboxes, pred_scores = preds\n total_loss = super().forward(pred_bboxes, pred_scores, batch)\n\n # Check for denoising metadata to compute denoising training loss\n if dn_meta is not None:\n dn_pos_idx, dn_num_group = dn_meta[\"dn_pos_idx\"], dn_meta[\"dn_num_group\"]\n assert len(batch[\"gt_groups\"]) == len(dn_pos_idx)\n\n # Get the match indices for denoising\n match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch[\"gt_groups\"])\n\n # Compute the denoising training loss\n dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix=\"_dn\", match_indices=match_indices)\n total_loss.update(dn_loss)\n else:\n # If no denoising metadata is provided, set denoising loss to zero\n total_loss.update({f\"{k}_dn\": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})\n\n return total_loss\n\n @staticmethod\n def get_dn_match_indices(\n dn_pos_idx: List[torch.Tensor], dn_num_group: int, gt_groups: List[int]\n ) -> List[Tuple[torch.Tensor, torch.Tensor]]:\n \"\"\"\n Get match indices for denoising.\n\n Args:\n dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.\n dn_num_group (int): Number of denoising groups.\n gt_groups (List[int]): List of integers representing number of ground truths per image.\n\n Returns:\n (List[Tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.\n \"\"\"\n dn_match_indices = []\n idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)\n for i, num_gt in enumerate(gt_groups):\n if num_gt > 0:\n gt_idx = 
torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]\n gt_idx = gt_idx.repeat(dn_num_group)\n assert len(dn_pos_idx[i]) == len(gt_idx), (\n f\"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.\"\n )\n dn_match_indices.append((dn_pos_idx[i], gt_idx))\n else:\n dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))\n return dn_match_indices", "chunk_type": "class", "name": "RTDETRDetectionLoss", "file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py", "start_line": 400, "end_line": 476, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.\n\nThis class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as\nan additional denoising training loss when provided with denoising metadata.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.loss.FocalLoss", "ultralytics.utils.loss.VarifocalLoss", "ultralytics.utils.metrics.bbox_iou", "ops.HungarianMatcher", "DETRLoss" ], "chunk_id": "class_RTDETRDetectionLoss_bddeaaab" }, { "content": "from typing import Any, Dict, List, Optional, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Tuple_2a5a3362" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_337a1d94" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_8ceb190b" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_6482ad80" }, { "content": "from scipy.optimize import linear_sum_assignment", "chunk_type": "import", "name": "linear_sum_assignment", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_linear_sum_assignment_e319a284" }, { "content": "from ultralytics.utils.metrics 
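The DETRLoss.forward contract documented above (stacked per-decoder-layer boxes and scores plus a batch dict carrying cls, bboxes and gt_groups) can be exercised directly with random tensors. The sketch below is illustrative only: the constructor call with just nc, the batch sizes, and the number of decoder layers are assumptions, not taken from the source.

```python
# Minimal sketch: compute DETR losses on random decoder outputs.
import torch
from ultralytics.models.utils.loss import DETRLoss

num_layers, bs, num_queries, nc = 3, 2, 100, 80
pred_bboxes = torch.rand(num_layers, bs, num_queries, 4)   # normalized xywh per decoder layer
pred_scores = torch.randn(num_layers, bs, num_queries, nc) # raw class logits per decoder layer
batch = {
    "cls": torch.randint(0, nc, (7,)),  # 7 ground-truth labels across the whole batch
    "bboxes": torch.rand(7, 4),         # matching normalized xywh boxes
    "gt_groups": [4, 3],                # 4 GTs in image 0, 3 in image 1
}

criterion = DETRLoss(nc=nc)  # assumed constructor usage; defaults keep aux_loss enabled
losses = criterion(pred_bboxes, pred_scores, batch)
print({k: float(v) for k, v in losses.items()})
```

RTDETRDetectionLoss.forward wraps this same call and, when dn_meta is supplied, repeats it on the denoising branch with a "_dn" postfix.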
import bbox_iou", "chunk_type": "import", "name": "bbox_iou", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bbox_iou_a1d5a92c" }, { "content": "from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh", "chunk_type": "import", "name": "xywh2xyxy, xyxy2xywh", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_xywh2xyxy, xyxy2xywh_74ca1074" }, { "content": "class HungarianMatcher(nn.Module):\n \"\"\"\n A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.\n\n HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost\n function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is\n used in end-to-end object detection models like DETR.\n\n Attributes:\n cost_gain (Dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'\n components.\n use_fl (bool): Whether to use Focal Loss for classification cost calculation.\n with_mask (bool): Whether the model makes mask predictions.\n num_sample_points (int): Number of sample points used in mask cost calculation.\n alpha (float): Alpha factor in Focal Loss calculation.\n gamma (float): Gamma factor in Focal Loss calculation.\n\n Methods:\n forward: Compute optimal assignment between predictions and ground truths for a batch.\n _cost_mask: Compute mask cost and dice cost if masks are predicted.\n\n Examples:\n Initialize a HungarianMatcher with custom cost gains\n >>> matcher = HungarianMatcher(cost_gain={\"class\": 2, \"bbox\": 5, \"giou\": 2})\n\n Perform matching between predictions and ground truth\n >>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100\n >>> pred_scores = torch.rand(2, 100, 80) # 80 classes\n >>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes\n >>> gt_classes = torch.randint(0, 80, (10,))\n >>> gt_groups = [5, 5] # 5 GT boxes per image\n >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)\n \"\"\"\n\n def __init__(\n self,\n cost_gain: Optional[Dict[str, float]] = None,\n use_fl: bool = True,\n with_mask: bool = False,\n num_sample_points: int = 12544,\n alpha: float = 0.25,\n gamma: float = 2.0,\n ):\n \"\"\"\n Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.\n\n Args:\n cost_gain (Dict[str, float], optional): Dictionary of cost coefficients for different matching cost\n components. 
Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.\n use_fl (bool): Whether to use Focal Loss for classification cost calculation.\n with_mask (bool): Whether the model makes mask predictions.\n num_sample_points (int): Number of sample points used in mask cost calculation.\n alpha (float): Alpha factor in Focal Loss calculation.\n gamma (float): Gamma factor in Focal Loss calculation.\n \"\"\"\n super().__init__()\n if cost_gain is None:\n cost_gain = {\"class\": 1, \"bbox\": 5, \"giou\": 2, \"mask\": 1, \"dice\": 1}\n self.cost_gain = cost_gain\n self.use_fl = use_fl\n self.with_mask = with_mask\n self.num_sample_points = num_sample_points\n self.alpha = alpha\n self.gamma = gamma\n\n def forward(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n gt_bboxes: torch.Tensor,\n gt_cls: torch.Tensor,\n gt_groups: List[int],\n masks: Optional[torch.Tensor] = None,\n gt_mask: Optional[List[torch.Tensor]] = None,\n ) -> List[Tuple[torch.Tensor, torch.Tensor]]:\n \"\"\"\n Compute optimal assignment between predictions and ground truth using Hungarian algorithm.\n\n This method calculates matching costs based on classification scores, bounding box coordinates, and optionally\n mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).\n pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,\n num_classes).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).\n gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).\n gt_groups (List[int]): Number of ground truth boxes for each image in the batch.\n masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).\n gt_mask (List[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).\n\n Returns:\n (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple\n (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)\n and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).\n For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).\n \"\"\"\n bs, nq, nc = pred_scores.shape\n\n if sum(gt_groups) == 0:\n return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]\n\n # Flatten to compute cost matrices in batch format\n pred_scores = pred_scores.detach().view(-1, nc)\n pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)\n pred_bboxes = pred_bboxes.detach().view(-1, 4)\n\n # Compute classification cost\n pred_scores = pred_scores[:, gt_cls]\n if self.use_fl:\n neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())\n pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())\n cost_class = pos_cost_class - neg_cost_class\n else:\n cost_class = -pred_scores\n\n # Compute L1 cost between boxes\n cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)\n\n # Compute GIoU cost between boxes, (bs*num_queries, num_gt)\n cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)\n\n # Combine costs into final cost matrix\n 
C = (\n self.cost_gain[\"class\"] * cost_class\n + self.cost_gain[\"bbox\"] * cost_bbox\n + self.cost_gain[\"giou\"] * cost_giou\n )\n\n # Add mask costs if available\n if self.with_mask:\n C += self._cost_mask(bs, gt_groups, masks, gt_mask)\n\n # Set invalid values (NaNs and infinities) to 0\n C[C.isnan() | C.isinf()] = 0.0\n\n C = C.view(bs, nq, -1).cpu()\n indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]\n gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)\n return [\n (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])\n for k, (i, j) in enumerate(indices)\n ]", "chunk_type": "class", "name": "HungarianMatcher", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 14, "end_line": 156, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.\n\nHungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost\nfunction that considers classification scores, bounding box coordinates, and optionally mask predictions. This is\nused in end-to-end object detection models like DETR.\n\nAttributes:\n cost_gain (Dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'\n components.\n use_fl (bool): Whether to use Focal Loss for classification cost calculation.\n with_mask (bool): Whether the model makes mask predictions.\n num_sample_points (int): Number of sample points used in mask cost calculation.\n alpha (float): Alpha factor in Focal Loss calculation.\n gamma (float): Gamma factor in Focal Loss calculation.\n\nMethods:\n forward: Compute optimal assignment between predictions and ground truths for a batch.\n _cost_mask: Compute mask cost and dice cost if masks are predicted.\n\nExamples:\n Initialize a HungarianMatcher with custom cost gains\n >>> matcher = HungarianMatcher(cost_gain={\"class\": 2, \"bbox\": 5, \"giou\": 2})\n\n Perform matching between predictions and ground truth\n >>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100\n >>> pred_scores = torch.rand(2, 100, 80) # 80 classes\n >>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes\n >>> gt_classes = torch.randint(0, 80, (10,))\n >>> gt_groups = [5, 5] # 5 GT boxes per image\n >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "scipy.optimize.linear_sum_assignment", "ultralytics.utils.metrics.bbox_iou", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh", "nn.Module" ], "chunk_id": "class_HungarianMatcher_ec3aa111" }, { "content": "def get_cdn_group(\n batch: Dict[str, Any],\n num_classes: int,\n num_queries: int,\n class_embed: torch.Tensor,\n num_dn: int = 100,\n cls_noise_ratio: float = 0.5,\n box_noise_scale: float = 1.0,\n training: bool = False,\n) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[Dict[str, Any]]]:\n \"\"\"\n Generate contrastive denoising training group with positive and negative samples from ground truths.\n\n This function creates denoising queries for contrastive denoising training by adding noise to ground truth\n bounding boxes and 
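Because the combined cost matrix is split per image along the ground-truth axis before linear_sum_assignment, the ground-truth indices returned by the matcher refer to the concatenated gt_bboxes tensor. A hedged usage sketch, mirroring the class docstring example (all counts and shapes are illustrative):

```python
# Run the matcher on random predictions and read back per-image index pairs.
import torch
from ultralytics.models.utils.ops import HungarianMatcher

matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
pred_boxes = torch.rand(2, 100, 4)     # (batch, queries, xywh)
pred_scores = torch.randn(2, 100, 80)  # raw logits; sigmoid/softmax is applied internally
gt_boxes = torch.rand(7, 4)            # all GT boxes in the batch, concatenated
gt_classes = torch.randint(0, 80, (7,))
gt_groups = [4, 3]                     # image 0 has 4 GTs, image 1 has 3

indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
for img_idx, (pred_idx, gt_idx) in enumerate(indices):
    print(f"image {img_idx}: queries {pred_idx.tolist()} -> GT rows {gt_idx.tolist()}")
```

Passing raw logits is fine here because the forward detaches the scores and applies sigmoid (focal path) or softmax itself before building the classification cost.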
class labels. It generates both positive and negative samples to improve model robustness.\n\n Args:\n batch (Dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),\n 'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (List[int]) indicating number of\n ground truths per image.\n num_classes (int): Total number of object classes.\n num_queries (int): Number of object queries.\n class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.\n num_dn (int): Number of denoising queries to generate.\n cls_noise_ratio (float): Noise ratio for class labels.\n box_noise_scale (float): Noise scale for bounding box coordinates.\n training (bool): Whether model is in training mode.\n\n Returns:\n padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).\n padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).\n attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).\n dn_meta (Dict[str, Any] | None): Meta information dictionary containing denoising parameters.\n\n Examples:\n Generate denoising group for training\n >>> batch = {\n ... \"cls\": torch.tensor([0, 1, 2]),\n ... \"bboxes\": torch.rand(3, 4),\n ... \"batch_idx\": torch.tensor([0, 0, 1]),\n ... \"gt_groups\": [2, 1],\n ... }\n >>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim\n >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)\n \"\"\"\n if (not training) or num_dn <= 0 or batch is None:\n return None, None, None, None\n gt_groups = batch[\"gt_groups\"]\n total_num = sum(gt_groups)\n max_nums = max(gt_groups)\n if max_nums == 0:\n return None, None, None, None\n\n num_group = num_dn // max_nums\n num_group = 1 if num_group == 0 else num_group\n # Pad gt to max_num of a batch\n bs = len(gt_groups)\n gt_cls = batch[\"cls\"] # (bs*num, )\n gt_bbox = batch[\"bboxes\"] # bs*num, 4\n b_idx = batch[\"batch_idx\"]\n\n # Each group has positive and negative queries\n dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )\n dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4\n dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )\n\n # Positive and negative mask\n # (bs*num*num_group, ), the second total_num*num_group part as negative samples\n neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num\n\n if cls_noise_ratio > 0:\n # Apply class label noise to half of the samples\n mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)\n idx = torch.nonzero(mask).squeeze(-1)\n # Randomly assign new class labels\n new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)\n dn_cls[idx] = new_label\n\n if box_noise_scale > 0:\n known_bbox = xywh2xyxy(dn_bbox)\n\n diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4\n\n rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0\n rand_part = torch.rand_like(dn_bbox)\n rand_part[neg_idx] += 1.0\n rand_part *= rand_sign\n known_bbox += rand_part * diff\n known_bbox.clip_(min=0.0, max=1.0)\n dn_bbox = xyxy2xywh(known_bbox)\n dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid\n\n num_dn = int(max_nums * 2 * num_group) # total denoising queries\n dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256\n padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], 
device=gt_cls.device)\n padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)\n\n map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])\n pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)\n\n map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])\n padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed\n padding_bbox[(dn_b_idx, map_indices)] = dn_bbox\n\n tgt_size = num_dn + num_queries\n attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)\n # Match query cannot see the reconstruct\n attn_mask[num_dn:, :num_dn] = True\n # Reconstruct cannot see each other\n for i in range(num_group):\n if i == 0:\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True\n if i == num_group - 1:\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True\n else:\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True\n dn_meta = {\n \"dn_pos_idx\": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],\n \"dn_num_group\": num_group,\n \"dn_num_split\": [num_dn, num_queries],\n }\n\n return (\n padding_cls.to(class_embed.device),\n padding_bbox.to(class_embed.device),\n attn_mask.to(class_embed.device),\n dn_meta,\n )", "chunk_type": "function", "name": "get_cdn_group", "file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py", "start_line": 189, "end_line": 317, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Generate contrastive denoising training group with positive and negative samples from ground truths.\n\nThis function creates denoising queries for contrastive denoising training by adding noise to ground truth\nbounding boxes and class labels. It generates both positive and negative samples to improve model robustness.\n\nArgs:\n batch (Dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),\n 'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (List[int]) indicating number of\n ground truths per image.\n num_classes (int): Total number of object classes.\n num_queries (int): Number of object queries.\n class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.\n num_dn (int): Number of denoising queries to generate.\n cls_noise_ratio (float): Noise ratio for class labels.\n box_noise_scale (float): Noise scale for bounding box coordinates.\n training (bool): Whether model is in training mode.\n\nReturns:\n padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).\n padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).\n attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).\n dn_meta (Dict[str, Any] | None): Meta information dictionary containing denoising parameters.\n\nExamples:\n Generate denoising group for training\n >>> batch = {\n ... \"cls\": torch.tensor([0, 1, 2]),\n ... \"bboxes\": torch.rand(3, 4),\n ... \"batch_idx\": torch.tensor([0, 0, 1]),\n ... \"gt_groups\": [2, 1],\n ... 
}\n >>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim\n >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)", "parameters": [ "batch: Dict[str, Any]", "num_classes: int", "num_queries: int", "class_embed: torch.Tensor", "num_dn: int", "cls_noise_ratio: float", "box_noise_scale: float", "training: bool" ], "return_type": "Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[Dict[str, Any]]]", "decorators": [], "complexity_score": 12, "dependencies": [ "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "scipy.optimize.linear_sum_assignment", "ultralytics.utils.metrics.bbox_iou", "ultralytics.utils.ops.xywh2xyxy", "ultralytics.utils.ops.xyxy2xywh" ], "chunk_id": "function_get_cdn_group_2042d0e1" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_dd98d879" }, { "content": "from typing import Any, Dict, List, Optional, Union", "chunk_type": "import", "name": "Any, Dict, List, Optional, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, Union_b1bbec1d" }, { "content": "from ultralytics.data.build import load_inference_source", "chunk_type": "import", "name": "load_inference_source", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_load_inference_source_d794da92" }, { "content": "from ultralytics.engine.model import Model", "chunk_type": "import", "name": "Model", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Model_4c74fc3d" }, { "content": "from ultralytics.models import yolo", "chunk_type": "import", "name": "yolo", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yolo_c5d758a9" }, { "content": "from ultralytics.nn.tasks import (\n ClassificationModel,\n DetectionModel,\n OBBModel,\n PoseModel,\n SegmentationModel,\n WorldModel,\n YOLOEModel,\n YOLOESegModel,\n)", "chunk_type": "import", "name": "ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel, YOLOEModel, YOLOESegModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 9, "end_line": 18, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": 
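Since get_cdn_group returns everything the denoising branch needs in a single call, inspecting its outputs clarifies the shapes involved. A hedged sketch with a tiny batch and default noise settings; the 256-dim embedding size is illustrative:

```python
# Build a small contrastive-denoising group and inspect the outputs.
import torch
from ultralytics.models.utils.ops import get_cdn_group

batch = {
    "cls": torch.tensor([0, 1, 2]),
    "bboxes": torch.rand(3, 4),            # normalized xywh
    "batch_idx": torch.tensor([0, 0, 1]),  # first two GTs belong to image 0
    "gt_groups": [2, 1],
}
class_embed = torch.rand(80, 256)          # 80 classes, 256-dim class embeddings

padding_cls, padding_bbox, attn_mask, dn_meta = get_cdn_group(
    batch, num_classes=80, num_queries=100, class_embed=class_embed, num_dn=100, training=True
)
print(padding_cls.shape)   # (bs, num_dn, embed_dim)
print(padding_bbox.shape)  # (bs, num_dn, 4)
print(attn_mask.shape)     # (num_dn + num_queries, num_dn + num_queries)
num_dn, num_queries = dn_meta["dn_num_split"]  # split sizes for the decoder output
```

dn_meta carries dn_pos_idx and dn_num_group, which RTDETRDetectionLoss.forward uses to build its denoising match indices, while dn_num_split gives the sizes needed to separate denoising queries from the regular matching queries downstream.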
null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel, YOLOEModel, YOLOESegModel_27430974" }, { "content": "from ultralytics.utils import ROOT, YAML", "chunk_type": "import", "name": "ROOT, YAML", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ROOT, YAML_b0400d76" }, { "content": "class YOLO(Model):\n \"\"\"\n YOLO (You Only Look Once) object detection model.\n\n This class provides a unified interface for YOLO models, automatically switching to specialized model types\n (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object\n detection, segmentation, classification, pose estimation, and oriented bounding box detection.\n\n Attributes:\n model: The loaded YOLO model instance.\n task: The task type (detect, segment, classify, pose, obb).\n overrides: Configuration overrides for the model.\n\n Methods:\n __init__: Initialize a YOLO model with automatic type detection.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n\n Examples:\n Load a pretrained YOLOv11n detection model\n >>> model = YOLO(\"yolo11n.pt\")\n\n Load a pretrained YOLO11n segmentation model\n >>> model = YOLO(\"yolo11n-seg.pt\")\n\n Initialize from a YAML configuration\n >>> model = YOLO(\"yolo11n.yaml\")\n \"\"\"\n\n def __init__(self, model: Union[str, Path] = \"yolo11n.pt\", task: Optional[str] = None, verbose: bool = False):\n \"\"\"\n Initialize a YOLO model.\n\n This constructor initializes a YOLO model, automatically switching to specialized model types\n (YOLOWorld or YOLOE) based on the model filename.\n\n Args:\n model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.\n task (str, optional): YOLO task specification, i.e. 
'detect', 'segment', 'classify', 'pose', 'obb'.\n Defaults to auto-detection based on model.\n verbose (bool): Display model info on load.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\") # load a pretrained YOLOv11n detection model\n >>> model = YOLO(\"yolo11n-seg.pt\") # load a pretrained YOLO11n segmentation model\n \"\"\"\n path = Path(model if isinstance(model, (str, Path)) else \"\")\n if \"-world\" in path.stem and path.suffix in {\".pt\", \".yaml\", \".yml\"}: # if YOLOWorld PyTorch model\n new_instance = YOLOWorld(path, verbose=verbose)\n self.__class__ = type(new_instance)\n self.__dict__ = new_instance.__dict__\n elif \"yoloe\" in path.stem and path.suffix in {\".pt\", \".yaml\", \".yml\"}: # if YOLOE PyTorch model\n new_instance = YOLOE(path, task=task, verbose=verbose)\n self.__class__ = type(new_instance)\n self.__dict__ = new_instance.__dict__\n else:\n # Continue with default YOLO initialization\n super().__init__(model=model, task=task, verbose=verbose)\n if hasattr(self.model, \"model\") and \"RTDETR\" in self.model.model[-1]._get_name(): # if RTDETR head\n from ultralytics import RTDETR\n\n new_instance = RTDETR(self)\n self.__class__ = type(new_instance)\n self.__dict__ = new_instance.__dict__\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Map head to model, trainer, validator, and predictor classes.\"\"\"\n return {\n \"classify\": {\n \"model\": ClassificationModel,\n \"trainer\": yolo.classify.ClassificationTrainer,\n \"validator\": yolo.classify.ClassificationValidator,\n \"predictor\": yolo.classify.ClassificationPredictor,\n },\n \"detect\": {\n \"model\": DetectionModel,\n \"trainer\": yolo.detect.DetectionTrainer,\n \"validator\": yolo.detect.DetectionValidator,\n \"predictor\": yolo.detect.DetectionPredictor,\n },\n \"segment\": {\n \"model\": SegmentationModel,\n \"trainer\": yolo.segment.SegmentationTrainer,\n \"validator\": yolo.segment.SegmentationValidator,\n \"predictor\": yolo.segment.SegmentationPredictor,\n },\n \"pose\": {\n \"model\": PoseModel,\n \"trainer\": yolo.pose.PoseTrainer,\n \"validator\": yolo.pose.PoseValidator,\n \"predictor\": yolo.pose.PosePredictor,\n },\n \"obb\": {\n \"model\": OBBModel,\n \"trainer\": yolo.obb.OBBTrainer,\n \"validator\": yolo.obb.OBBValidator,\n \"predictor\": yolo.obb.OBBPredictor,\n },\n }", "chunk_type": "class", "name": "YOLO", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 22, "end_line": 121, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "YOLO (You Only Look Once) object detection model.\n\nThis class provides a unified interface for YOLO models, automatically switching to specialized model types\n(YOLOWorld or YOLOE) based on the model filename. 
It supports various computer vision tasks including object\ndetection, segmentation, classification, pose estimation, and oriented bounding box detection.\n\nAttributes:\n model: The loaded YOLO model instance.\n task: The task type (detect, segment, classify, pose, obb).\n overrides: Configuration overrides for the model.\n\nMethods:\n __init__: Initialize a YOLO model with automatic type detection.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n\nExamples:\n Load a pretrained YOLOv11n detection model\n >>> model = YOLO(\"yolo11n.pt\")\n\n Load a pretrained YOLO11n segmentation model\n >>> model = YOLO(\"yolo11n-seg.pt\")\n\n Initialize from a YAML configuration\n >>> model = YOLO(\"yolo11n.yaml\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "ultralytics.data.build.load_inference_source", "ultralytics.engine.model.Model", "ultralytics.models.yolo", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.OBBModel", "ultralytics.nn.tasks.PoseModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.ROOT", "ultralytics.utils.YAML", "ultralytics.RTDETR", "Model" ], "chunk_id": "class_YOLO_33c1b65a" }, { "content": "class YOLOWorld(Model):\n \"\"\"\n YOLO-World object detection model.\n\n YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions\n without requiring training on specific classes. It extends the YOLO architecture to support real-time\n open-vocabulary detection.\n\n Attributes:\n model: The loaded YOLO-World model instance.\n task: Always set to 'detect' for object detection.\n overrides: Configuration overrides for the model.\n\n Methods:\n __init__: Initialize YOLOv8-World model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n set_classes: Set the model's class names for detection.\n\n Examples:\n Load a YOLOv8-World model\n >>> model = YOLOWorld(\"yolov8s-world.pt\")\n\n Set custom classes for detection\n >>> model.set_classes([\"person\", \"car\", \"bicycle\"])\n \"\"\"\n\n def __init__(self, model: Union[str, Path] = \"yolov8s-world.pt\", verbose: bool = False) -> None:\n \"\"\"\n Initialize YOLOv8-World model with a pre-trained model file.\n\n Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default\n COCO class names.\n\n Args:\n model (str | Path): Path to the pre-trained model file. 
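The filename-based dispatch described in this constructor means the plain YOLO entry point can hand back a specialized subclass instance. A hedged sketch; the weight files are the ones named in the docstrings and are fetched on first use:

```python
# Same constructor, different runtime types depending on the model filename.
from ultralytics import YOLO

detector = YOLO("yolo11n.pt")       # standard detection model
world = YOLO("yolov8s-world.pt")    # "-world" in the stem switches the instance to YOLOWorld
print(type(detector).__name__, type(world).__name__)  # e.g. "YOLO", "YOLOWorld"
```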
Supports *.pt and *.yaml formats.\n verbose (bool): If True, prints additional information during initialization.\n \"\"\"\n super().__init__(model=model, task=\"detect\", verbose=verbose)\n\n # Assign default COCO class names when there are no custom names\n if not hasattr(self.model, \"names\"):\n self.model.names = YAML.load(ROOT / \"cfg/datasets/coco8.yaml\").get(\"names\")\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Map head to model, validator, and predictor classes.\"\"\"\n return {\n \"detect\": {\n \"model\": WorldModel,\n \"validator\": yolo.detect.DetectionValidator,\n \"predictor\": yolo.detect.DetectionPredictor,\n \"trainer\": yolo.world.WorldTrainer,\n }\n }\n\n def set_classes(self, classes: List[str]) -> None:\n \"\"\"\n Set the model's class names for detection.\n\n Args:\n classes (List[str]): A list of categories i.e. [\"person\"].\n \"\"\"\n self.model.set_classes(classes)\n # Remove background if it's given\n background = \" \"\n if background in classes:\n classes.remove(background)\n self.model.names = classes\n\n # Reset method class names\n if self.predictor:\n self.predictor.model.names = classes", "chunk_type": "class", "name": "YOLOWorld", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 124, "end_line": 195, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": "YOLO-World object detection model.\n\nYOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions\nwithout requiring training on specific classes. It extends the YOLO architecture to support real-time\nopen-vocabulary detection.\n\nAttributes:\n model: The loaded YOLO-World model instance.\n task: Always set to 'detect' for object detection.\n overrides: Configuration overrides for the model.\n\nMethods:\n __init__: Initialize YOLOv8-World model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n set_classes: Set the model's class names for detection.\n\nExamples:\n Load a YOLOv8-World model\n >>> model = YOLOWorld(\"yolov8s-world.pt\")\n\n Set custom classes for detection\n >>> model.set_classes([\"person\", \"car\", \"bicycle\"])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "ultralytics.data.build.load_inference_source", "ultralytics.engine.model.Model", "ultralytics.models.yolo", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.OBBModel", "ultralytics.nn.tasks.PoseModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.ROOT", "ultralytics.utils.YAML", "ultralytics.RTDETR", "Model" ], "chunk_id": "class_YOLOWorld_17b0b132" }, { "content": "class YOLOE(Model):\n \"\"\"\n YOLOE object detection and segmentation model.\n\n YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with\n improved performance and additional features like visual and text positional embeddings.\n\n Attributes:\n model: The loaded YOLOE model instance.\n task: The task type (detect or segment).\n overrides: Configuration overrides for the model.\n\n Methods:\n __init__: Initialize YOLOE model with a pre-trained model file.\n task_map: Map tasks to their corresponding 
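Because set_classes both updates model.names and strips the background entry, a typical open-vocabulary flow is just two calls. A hedged sketch; the image path is a placeholder and text embeddings are computed internally when the vocabulary is set:

```python
# Restrict the YOLO-World vocabulary, then run inference.
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")
model.set_classes(["person", "bus", "backpack"])  # also propagates to an active predictor
results = model.predict("path/to/image.jpg")      # placeholder path
```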
model, trainer, validator, and predictor classes.\n get_text_pe: Get text positional embeddings for the given texts.\n get_visual_pe: Get visual positional embeddings for the given image and visual features.\n set_vocab: Set vocabulary and class names for the YOLOE model.\n get_vocab: Get vocabulary for the given class names.\n set_classes: Set the model's class names and embeddings for detection.\n val: Validate the model using text or visual prompts.\n predict: Run prediction on images, videos, directories, streams, etc.\n\n Examples:\n Load a YOLOE detection model\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n\n Set vocabulary and class names\n >>> model.set_vocab([\"person\", \"car\", \"dog\"], [\"person\", \"car\", \"dog\"])\n\n Predict with visual prompts\n >>> prompts = {\"bboxes\": [[10, 20, 100, 200]], \"cls\": [\"person\"]}\n >>> results = model.predict(\"image.jpg\", visual_prompts=prompts)\n \"\"\"\n\n def __init__(\n self, model: Union[str, Path] = \"yoloe-11s-seg.pt\", task: Optional[str] = None, verbose: bool = False\n ) -> None:\n \"\"\"\n Initialize YOLOE model with a pre-trained model file.\n\n Args:\n model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.\n task (str, optional): Task type for the model. Auto-detected if None.\n verbose (bool): If True, prints additional information during initialization.\n \"\"\"\n super().__init__(model=model, task=task, verbose=verbose)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Map head to model, validator, and predictor classes.\"\"\"\n return {\n \"detect\": {\n \"model\": YOLOEModel,\n \"validator\": yolo.yoloe.YOLOEDetectValidator,\n \"predictor\": yolo.detect.DetectionPredictor,\n \"trainer\": yolo.yoloe.YOLOETrainer,\n },\n \"segment\": {\n \"model\": YOLOESegModel,\n \"validator\": yolo.yoloe.YOLOESegValidator,\n \"predictor\": yolo.segment.SegmentationPredictor,\n \"trainer\": yolo.yoloe.YOLOESegTrainer,\n },\n }\n\n def get_text_pe(self, texts):\n \"\"\"Get text positional embeddings for the given texts.\"\"\"\n assert isinstance(self.model, YOLOEModel)\n return self.model.get_text_pe(texts)\n\n def get_visual_pe(self, img, visual):\n \"\"\"\n Get visual positional embeddings for the given image and visual features.\n\n This method extracts positional embeddings from visual features based on the input image. It requires\n that the model is an instance of YOLOEModel.\n\n Args:\n img (torch.Tensor): Input image tensor.\n visual (torch.Tensor): Visual features extracted from the image.\n\n Returns:\n (torch.Tensor): Visual positional embeddings.\n\n Examples:\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n >>> img = torch.rand(1, 3, 640, 640)\n >>> visual_features = torch.rand(1, 1, 80, 80)\n >>> pe = model.get_visual_pe(img, visual_features)\n \"\"\"\n assert isinstance(self.model, YOLOEModel)\n return self.model.get_visual_pe(img, visual)\n\n def set_vocab(self, vocab: List[str], names: List[str]) -> None:\n \"\"\"\n Set vocabulary and class names for the YOLOE model.\n\n This method configures the vocabulary and class names used by the model for text processing and\n classification tasks. 
The model must be an instance of YOLOEModel.\n\n Args:\n vocab (List[str]): Vocabulary list containing tokens or words used by the model for text processing.\n names (List[str]): List of class names that the model can detect or classify.\n\n Raises:\n AssertionError: If the model is not an instance of YOLOEModel.\n\n Examples:\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n >>> model.set_vocab([\"person\", \"car\", \"dog\"], [\"person\", \"car\", \"dog\"])\n \"\"\"\n assert isinstance(self.model, YOLOEModel)\n self.model.set_vocab(vocab, names=names)\n\n def get_vocab(self, names):\n \"\"\"Get vocabulary for the given class names.\"\"\"\n assert isinstance(self.model, YOLOEModel)\n return self.model.get_vocab(names)\n\n def set_classes(self, classes: List[str], embeddings) -> None:\n \"\"\"\n Set the model's class names and embeddings for detection.\n\n Args:\n classes (List[str]): A list of categories i.e. [\"person\"].\n embeddings (torch.Tensor): Embeddings corresponding to the classes.\n \"\"\"\n assert isinstance(self.model, YOLOEModel)\n self.model.set_classes(classes, embeddings)\n # Verify no background class is present\n assert \" \" not in classes\n self.model.names = classes\n\n # Reset method class names\n if self.predictor:\n self.predictor.model.names = classes\n\n def val(\n self,\n validator=None,\n load_vp: bool = False,\n refer_data: Optional[str] = None,\n **kwargs,\n ):\n \"\"\"\n Validate the model using text or visual prompts.\n\n Args:\n validator (callable, optional): A callable validator function. If None, a default validator is loaded.\n load_vp (bool): Whether to load visual prompts. If False, text prompts are used.\n refer_data (str, optional): Path to the reference data for visual prompts.\n **kwargs (Any): Additional keyword arguments to override default settings.\n\n Returns:\n (dict): Validation statistics containing metrics computed during validation.\n \"\"\"\n custom = {\"rect\": not load_vp} # method defaults\n args = {**self.overrides, **custom, **kwargs, \"mode\": \"val\"} # highest priority args on the right\n\n validator = (validator or self._smart_load(\"validator\"))(args=args, _callbacks=self.callbacks)\n validator(model=self.model, load_vp=load_vp, refer_data=refer_data)\n self.metrics = validator.metrics\n return validator.metrics\n\n def predict(\n self,\n source=None,\n stream: bool = False,\n visual_prompts: Dict[str, List] = {},\n refer_image=None,\n predictor=None,\n **kwargs,\n ):\n \"\"\"\n Run prediction on images, videos, directories, streams, etc.\n\n Args:\n source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,\n directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.\n stream (bool): Whether to stream the prediction results. If True, results are yielded as a\n generator as they are computed.\n visual_prompts (Dict[str, List]): Dictionary containing visual prompts for the model. Must include\n 'bboxes' and 'cls' keys when non-empty.\n refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.\n predictor (callable, optional): Custom predictor function. 
If None, a predictor is automatically\n loaded based on the task.\n **kwargs (Any): Additional keyword arguments passed to the predictor.\n\n Returns:\n (List | generator): List of Results objects or generator of Results objects if stream=True.\n\n Examples:\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n >>> results = model.predict(\"path/to/image.jpg\")\n >>> # With visual prompts\n >>> prompts = {\"bboxes\": [[10, 20, 100, 200]], \"cls\": [\"person\"]}\n >>> results = model.predict(\"path/to/image.jpg\", visual_prompts=prompts)\n \"\"\"\n if len(visual_prompts):\n assert \"bboxes\" in visual_prompts and \"cls\" in visual_prompts, (\n f\"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}\"\n )\n assert len(visual_prompts[\"bboxes\"]) == len(visual_prompts[\"cls\"]), (\n f\"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and \"\n f\"{len(visual_prompts['cls'])} respectively\"\n )\n if not isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):\n self.predictor = (predictor or yolo.yoloe.YOLOEVPDetectPredictor)(\n overrides={\n \"task\": self.model.task,\n \"mode\": \"predict\",\n \"save\": False,\n \"verbose\": refer_image is None,\n \"batch\": 1,\n },\n _callbacks=self.callbacks,\n )\n\n num_cls = (\n max(len(set(c)) for c in visual_prompts[\"cls\"])\n if isinstance(source, list) and refer_image is None # means multiple images\n else len(set(visual_prompts[\"cls\"]))\n )\n self.model.model[-1].nc = num_cls\n self.model.names = [f\"object{i}\" for i in range(num_cls)]\n self.predictor.set_prompts(visual_prompts.copy())\n self.predictor.setup_model(model=self.model)\n\n if refer_image is None and source is not None:\n dataset = load_inference_source(source)\n if dataset.mode in {\"video\", \"stream\"}:\n # NOTE: set the first frame as refer image for videos/streams inference\n refer_image = next(iter(dataset))[1][0]\n if refer_image is not None:\n vpe = self.predictor.get_vpe(refer_image)\n self.model.set_classes(self.model.names, vpe)\n self.task = \"segment\" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else \"detect\"\n self.predictor = None # reset predictor\n elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):\n self.predictor = None # reset predictor if no visual prompts\n\n return super().predict(source, stream, **kwargs)", "chunk_type": "class", "name": "YOLOE", "file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py", "start_line": 198, "end_line": 440, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "YOLOE object detection and segmentation model.\n\nYOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with\nimproved performance and additional features like visual and text positional embeddings.\n\nAttributes:\n model: The loaded YOLOE model instance.\n task: The task type (detect or segment).\n overrides: Configuration overrides for the model.\n\nMethods:\n __init__: Initialize YOLOE model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n get_text_pe: Get text positional embeddings for the given texts.\n get_visual_pe: Get visual positional embeddings for the given image and visual features.\n set_vocab: Set vocabulary and class names for the YOLOE model.\n get_vocab: Get vocabulary for the given class names.\n set_classes: Set the model's class names and embeddings for detection.\n val: Validate the model using text or 
visual prompts.\n predict: Run prediction on images, videos, directories, streams, etc.\n\nExamples:\n Load a YOLOE detection model\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n\n Set vocabulary and class names\n >>> model.set_vocab([\"person\", \"car\", \"dog\"], [\"person\", \"car\", \"dog\"])\n\n Predict with visual prompts\n >>> prompts = {\"bboxes\": [[10, 20, 100, 200]], \"cls\": [\"person\"]}\n >>> results = model.predict(\"image.jpg\", visual_prompts=prompts)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "ultralytics.data.build.load_inference_source", "ultralytics.engine.model.Model", "ultralytics.models.yolo", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.nn.tasks.DetectionModel", "ultralytics.nn.tasks.OBBModel", "ultralytics.nn.tasks.PoseModel", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.nn.tasks.WorldModel", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.ROOT", "ultralytics.utils.YAML", "ultralytics.RTDETR", "Model" ], "chunk_id": "class_YOLOE_237af088" }, { "content": "from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe", "chunk_type": "import", "name": "classify, detect, obb, pose, segment, world, yoloe", "file_path": "ultralytics\\ultralytics\\models\\yolo\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 86, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_classify, detect, obb, pose, segment, world, yoloe_bf01e42a" }, { "content": "from .model import YOLO, YOLOE, YOLOWorld", "chunk_type": "import", "name": "YOLO, YOLOE, YOLOWorld", "file_path": "ultralytics\\ultralytics\\models\\yolo\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLO, YOLOE, YOLOWorld_e96d119e" }, { "content": "__all__ = \"classify\", \"segment\", \"detect\", \"pose\", \"obb\", \"world\", \"yoloe\", \"YOLO\", \"YOLOWorld\", \"YOLOE\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 104, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___f3807658" }, { "content": "import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_999a2db0" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_6bc7677f" }, { "content": "from functools import partial", 
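For YOLOE, the text-prompt path pairs get_text_pe with set_classes, while the visual-prompt path goes through predict as shown in its docstring. A hedged sketch of the text-prompt flow; the class names and image path are placeholders:

```python
# Text-prompted classes for YOLOE: embed the names, then register them.
from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")
names = ["person", "bicycle", "traffic light"]
model.set_classes(names, model.get_text_pe(names))  # embeddings must line up with the class list
results = model.predict("path/to/image.jpg")        # placeholder path
```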
"chunk_type": "import", "name": "partial", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_partial_43112863" }, { "content": "from typing import Any, Optional, Tuple, Type, Union", "chunk_type": "import", "name": "Any, Optional, Tuple, Type, Union", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Optional, Tuple, Type, Union_5730cfac" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_06181710" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_621563b7" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_48ff619f" }, { "content": "from torch import Tensor, nn", "chunk_type": "import", "name": "Tensor, nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Tensor, nn_82500d82" }, { "content": "from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock", "chunk_type": "import", "name": "MLP, LayerNorm2d, MLPBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MLP, LayerNorm2d, MLPBlock_4e9f6fe4" }, { "content": "from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer", "chunk_type": "import", "name": "Attention, TwoWayAttentionBlock, TwoWayTransformer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Attention, TwoWayAttentionBlock, 
TwoWayTransformer_c782685c" }, { "content": "from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition", "chunk_type": "import", "name": "add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 116, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition_0195c9a6" }, { "content": "class DropPath(nn.Module):\n \"\"\"\n Implements stochastic depth regularization for neural networks during training.\n\n Attributes:\n drop_prob (float): Probability of dropping a path during training.\n scale_by_keep (bool): Whether to scale the output by the keep probability.\n\n Methods:\n forward: Applies stochastic depth to input tensor during training, with optional scaling.\n\n Examples:\n >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)\n >>> x = torch.randn(32, 64, 224, 224)\n >>> output = drop_path(x)\n \"\"\"\n\n def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):\n \"\"\"Initialize DropPath module for stochastic depth regularization during training.\"\"\"\n super().__init__()\n self.drop_prob = drop_prob\n self.scale_by_keep = scale_by_keep\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Apply stochastic depth to input tensor during training, with optional scaling.\"\"\"\n if self.drop_prob == 0.0 or not self.training:\n return x\n keep_prob = 1 - self.drop_prob\n shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n if keep_prob > 0.0 and self.scale_by_keep:\n random_tensor.div_(keep_prob)\n return x * random_tensor", "chunk_type": "class", "name": "DropPath", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 19, "end_line": 51, "start_col": 0, "end_col": 32, "parent_name": null, "docstring": "Implements stochastic depth regularization for neural networks during training.\n\nAttributes:\n drop_prob (float): Probability of dropping a path during training.\n scale_by_keep (bool): Whether to scale the output by the keep probability.\n\nMethods:\n forward: Applies stochastic depth to input tensor during training, with optional scaling.\n\nExamples:\n >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)\n >>> x = torch.randn(32, 64, 224, 224)\n >>> output = drop_path(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_DropPath_b0d4151c" }, { "content": "class MaskDownSampler(nn.Module):\n \"\"\"\n A mask downsampling and embedding module for efficient processing of input masks.\n\n This class implements a mask downsampler that 
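DropPath's behaviour is easy to confirm from the forward shown above: per-sample Bernoulli masking with rescaling in training mode, and the identity in eval mode. A small sketch:

```python
# Stochastic depth: active only in training mode.
import torch
from ultralytics.models.sam.modules.blocks import DropPath

dp = DropPath(drop_prob=0.5, scale_by_keep=True)
x = torch.ones(4, 8)

dp.train()
print(dp(x)[:, 0])  # roughly half the samples zeroed, survivors scaled by 1/keep_prob = 2.0

dp.eval()
print(torch.equal(dp(x), x))  # True: no-op at inference
```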
progressively reduces the spatial dimensions of input masks\n while expanding their channel dimensions using convolutional layers, layer normalization, and activation\n functions.\n\n Attributes:\n encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and\n activation functions for downsampling and embedding masks.\n\n Methods:\n forward: Downsamples and encodes input mask to embed_dim channels.\n\n Examples:\n >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)\n >>> input_mask = torch.randn(1, 1, 256, 256)\n >>> output = mask_downsampler(input_mask)\n >>> print(output.shape)\n torch.Size([1, 256, 16, 16])\n \"\"\"\n\n def __init__(\n self,\n embed_dim: int = 256,\n kernel_size: int = 4,\n stride: int = 4,\n padding: int = 0,\n total_stride: int = 16,\n activation: Type[nn.Module] = nn.GELU,\n ):\n \"\"\"Initialize a mask downsampler module for progressive downsampling and channel expansion.\"\"\"\n super().__init__()\n num_layers = int(math.log2(total_stride) // math.log2(stride))\n assert stride**num_layers == total_stride\n self.encoder = nn.Sequential()\n mask_in_chans, mask_out_chans = 1, 1\n for _ in range(num_layers):\n mask_out_chans = mask_in_chans * (stride**2)\n self.encoder.append(\n nn.Conv2d(\n mask_in_chans,\n mask_out_chans,\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n )\n )\n self.encoder.append(LayerNorm2d(mask_out_chans))\n self.encoder.append(activation())\n mask_in_chans = mask_out_chans\n\n self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d.\"\"\"\n return self.encoder(x)", "chunk_type": "class", "name": "MaskDownSampler", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 54, "end_line": 111, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "A mask downsampling and embedding module for efficient processing of input masks.\n\nThis class implements a mask downsampler that progressively reduces the spatial dimensions of input masks\nwhile expanding their channel dimensions using convolutional layers, layer normalization, and activation\nfunctions.\n\nAttributes:\n encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and\n activation functions for downsampling and embedding masks.\n\nMethods:\n forward: Downsamples and encodes input mask to embed_dim channels.\n\nExamples:\n >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)\n >>> input_mask = torch.randn(1, 1, 256, 256)\n >>> output = mask_downsampler(input_mask)\n >>> print(output.shape)\n torch.Size([1, 256, 16, 16])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": 
"class_MaskDownSampler_c51bf8d6" }, { "content": "class CXBlock(nn.Module):\n \"\"\"\n ConvNeXt Block for efficient feature extraction in convolutional neural networks.\n\n This block implements a modified version of the ConvNeXt architecture, offering improved performance and\n flexibility in feature extraction.\n\n Attributes:\n dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.\n norm (LayerNorm2d): Layer normalization applied to channels.\n pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.\n act (nn.GELU): GELU activation function.\n pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.\n gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.\n drop_path (nn.Module): DropPath layer for stochastic depth regularization.\n\n Methods:\n forward: Processes the input tensor through the ConvNeXt block.\n\n Examples:\n >>> import torch\n >>> x = torch.randn(1, 64, 56, 56)\n >>> block = CXBlock(dim=64, kernel_size=7, padding=3)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n kernel_size: int = 7,\n padding: int = 3,\n drop_path: float = 0.0,\n layer_scale_init_value: float = 1e-6,\n use_dwconv: bool = True,\n ):\n \"\"\"\n Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.\n\n This block implements a modified version of the ConvNeXt architecture, offering improved performance and\n flexibility in feature extraction.\n\n Args:\n dim (int): Number of input channels.\n kernel_size (int): Size of the convolutional kernel.\n padding (int): Padding size for the convolution.\n drop_path (float): Stochastic depth rate.\n layer_scale_init_value (float): Initial value for Layer Scale.\n use_dwconv (bool): Whether to use depthwise convolution.\n\n Examples:\n >>> block = CXBlock(dim=64, kernel_size=7, padding=3)\n >>> x = torch.randn(1, 64, 32, 32)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 64, 32, 32])\n \"\"\"\n super().__init__()\n self.dwconv = nn.Conv2d(\n dim,\n dim,\n kernel_size=kernel_size,\n padding=padding,\n groups=dim if use_dwconv else 1,\n ) # depthwise conv\n self.norm = LayerNorm2d(dim, eps=1e-6)\n self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers\n self.act = nn.GELU()\n self.pwconv2 = nn.Linear(4 * dim, dim)\n self.gamma = (\n nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)\n if layer_scale_init_value > 0\n else None\n )\n self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Apply ConvNeXt block operations to input tensor, including convolutions and residual connection.\"\"\"\n input = x\n x = self.dwconv(x)\n x = self.norm(x)\n x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)\n x = self.pwconv1(x)\n x = self.act(x)\n x = self.pwconv2(x)\n if self.gamma is not None:\n x = self.gamma * x\n x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)\n\n x = input + self.drop_path(x)\n return x", "chunk_type": "class", "name": "CXBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 114, "end_line": 205, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "ConvNeXt Block for efficient feature extraction in convolutional neural networks.\n\nThis block implements a modified version of the ConvNeXt architecture, offering improved performance and\nflexibility 
in feature extraction.\n\nAttributes:\n dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.\n norm (LayerNorm2d): Layer normalization applied to channels.\n pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.\n act (nn.GELU): GELU activation function.\n pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.\n gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.\n drop_path (nn.Module): DropPath layer for stochastic depth regularization.\n\nMethods:\n forward: Processes the input tensor through the ConvNeXt block.\n\nExamples:\n >>> import torch\n >>> x = torch.randn(1, 64, 56, 56)\n >>> block = CXBlock(dim=64, kernel_size=7, padding=3)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_CXBlock_2e708231" }, { "content": "class Fuser(nn.Module):\n \"\"\"\n A module for fusing features through multiple layers of a neural network.\n\n This class applies a series of identical layers to an input tensor, optionally projecting the input first.\n\n Attributes:\n proj (nn.Module): An optional input projection layer. 
Identity if no projection is needed.\n layers (nn.ModuleList): A list of identical layers to be applied sequentially.\n\n Methods:\n forward: Applies the fuser to an input tensor.\n\n Examples:\n >>> layer = CXBlock(dim=256)\n >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = fuser(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])\n \"\"\"\n\n def __init__(self, layer: nn.Module, num_layers: int, dim: Optional[int] = None, input_projection: bool = False):\n \"\"\"\n Initialize the Fuser module for feature fusion through multiple layers.\n\n This module creates a sequence of identical layers and optionally applies an input projection.\n\n Args:\n layer (nn.Module): The layer to be replicated in the fuser.\n num_layers (int): The number of times to replicate the layer.\n dim (int | None): The dimension for input projection, if used.\n input_projection (bool): Whether to use input projection.\n\n Examples:\n >>> layer = nn.Linear(64, 64)\n >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)\n >>> input_tensor = torch.randn(1, 64)\n >>> output = fuser(input_tensor)\n \"\"\"\n super().__init__()\n self.proj = nn.Identity()\n self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])\n\n if input_projection:\n assert dim is not None\n self.proj = nn.Conv2d(dim, dim, kernel_size=1)\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Apply a series of layers to the input tensor, optionally projecting it first.\"\"\"\n x = self.proj(x)\n for layer in self.layers:\n x = layer(x)\n return x", "chunk_type": "class", "name": "Fuser", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 208, "end_line": 261, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "A module for fusing features through multiple layers of a neural network.\n\nThis class applies a series of identical layers to an input tensor, optionally projecting the input first.\n\nAttributes:\n proj (nn.Module): An optional input projection layer. 
Identity if no projection is needed.\n layers (nn.ModuleList): A list of identical layers to be applied sequentially.\n\nMethods:\n forward: Applies the fuser to an input tensor.\n\nExamples:\n >>> layer = CXBlock(dim=256)\n >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = fuser(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_Fuser_76f231f1" }, { "content": "class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):\n \"\"\"\n A two-way attention block for performing self-attention and cross-attention in both directions.\n\n This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on\n sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and\n cross-attention from dense to sparse inputs.\n\n Attributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after the first attention block.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after the second attention block.\n mlp (MLP): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after the MLP block.\n norm4 (nn.LayerNorm): Layer normalization after the third attention block.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.\n\n Methods:\n forward: Processes input through the attention blocks and MLP.\n\n Examples:\n >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)\n >>> sparse_input = torch.randn(1, 100, 256)\n >>> dense_input = torch.randn(1, 256, 16, 16)\n >>> sparse_output, dense_output = block(sparse_input, dense_input)\n \"\"\"\n\n def __init__(\n self,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int = 2048,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n skip_first_layer_pe: bool = False,\n ) -> None:\n \"\"\"\n Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.\n\n This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse\n inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention\n from dense to sparse inputs.\n\n Args:\n embedding_dim (int): The channel dimension of the embeddings.\n num_heads (int): The number of heads in the attention layers.\n mlp_dim (int): The hidden dimension of the MLP block.\n activation (Type[nn.Module]): The activation function of the MLP block.\n attention_downsample_rate (int): The downsample rate for attention computations.\n skip_first_layer_pe (bool): Whether to skip the positional encoding in the first 
layer.\n\n Examples:\n >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> sparse_inputs = torch.randn(1, 100, 256)\n >>> dense_inputs = torch.randn(1, 256, 32, 32)\n >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)\n \"\"\"\n super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)\n self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)", "chunk_type": "class", "name": "SAM2TwoWayAttentionBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 264, "end_line": 324, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": "A two-way attention block for performing self-attention and cross-attention in both directions.\n\nThis block extends the TwoWayAttentionBlock and consists of four main components: self-attention on\nsparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and\ncross-attention from dense to sparse inputs.\n\nAttributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after the first attention block.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after the second attention block.\n mlp (MLP): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after the MLP block.\n norm4 (nn.LayerNorm): Layer normalization after the third attention block.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.\n\nMethods:\n forward: Processes input through the attention blocks and MLP.\n\nExamples:\n >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)\n >>> sparse_input = torch.randn(1, 100, 256)\n >>> dense_input = torch.randn(1, 256, 16, 16)\n >>> sparse_output, dense_output = block(sparse_input, dense_input)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "TwoWayAttentionBlock" ], "chunk_id": "class_SAM2TwoWayAttentionBlock_a7497d58" }, { "content": "class SAM2TwoWayTransformer(TwoWayTransformer):\n \"\"\"\n A Two-Way Transformer module for simultaneous attention to image and query points.\n\n This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an\n input image using queries with supplied positional embeddings. 
It is particularly useful for tasks like\n object detection, image segmentation, and point cloud processing.\n\n Attributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\n Methods:\n forward: Processes input image embeddings and query embeddings through the transformer.\n\n Examples:\n >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 64, 64)\n >>> query_embedding = torch.randn(1, 100, 256)\n >>> output = transformer(image_embedding, query_embedding)\n >>> print(output[0].shape, output[1].shape)\n torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n\n def __init__(\n self,\n depth: int,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n ) -> None:\n \"\"\"\n Initialize a SAM2TwoWayTransformer instance.\n\n This transformer decoder attends to an input image using queries with supplied positional embeddings.\n It is designed for tasks like object detection, image segmentation, and point cloud processing.\n\n Args:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for the input embeddings.\n num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.\n mlp_dim (int): Channel dimension internal to the MLP block.\n activation (Type[nn.Module]): Activation function to use in the MLP block.\n attention_downsample_rate (int): Downsampling rate for attention computations.\n\n Examples:\n >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> transformer\n SAM2TwoWayTransformer(\n (layers): ModuleList(\n (0-4): 5 x SAM2TwoWayAttentionBlock(...)\n )\n (final_attn_token_to_image): Attention(...)\n (norm_final_attn): LayerNorm(...)\n )\n \"\"\"\n super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)\n self.layers = nn.ModuleList()\n for i in range(depth):\n self.layers.append(\n SAM2TwoWayAttentionBlock(\n embedding_dim=embedding_dim,\n num_heads=num_heads,\n mlp_dim=mlp_dim,\n activation=activation,\n attention_downsample_rate=attention_downsample_rate,\n skip_first_layer_pe=(i == 0),\n )\n )", "chunk_type": "class", "name": "SAM2TwoWayTransformer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 327, "end_line": 402, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": "A Two-Way Transformer module for simultaneous attention to image and query points.\n\nThis class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an\ninput image using queries with supplied positional embeddings. 
It is particularly useful for tasks like\nobject detection, image segmentation, and point cloud processing.\n\nAttributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\nMethods:\n forward: Processes input image embeddings and query embeddings through the transformer.\n\nExamples:\n >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 64, 64)\n >>> query_embedding = torch.randn(1, 100, 256)\n >>> output = transformer(image_embedding, query_embedding)\n >>> print(output[0].shape, output[1].shape)\n torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "TwoWayTransformer" ], "chunk_id": "class_SAM2TwoWayTransformer_14edd00a" }, { "content": "class RoPEAttention(Attention):\n \"\"\"\n Implements rotary position encoding for attention mechanisms in transformer architectures.\n\n This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance\n the positional awareness of the attention mechanism.\n\n Attributes:\n compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.\n freqs_cis (torch.Tensor): Precomputed frequency tensor for rotary encoding.\n rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.\n\n Methods:\n forward: Applies rotary position encoding and computes attention between query, key, and value tensors.\n\n Examples:\n >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))\n >>> q = torch.randn(1, 1024, 256)\n >>> k = torch.randn(1, 1024, 256)\n >>> v = torch.randn(1, 1024, 256)\n >>> output = rope_attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 1024, 256])\n \"\"\"\n\n def __init__(\n self,\n *args,\n rope_theta: float = 10000.0,\n rope_k_repeat: bool = False,\n feat_sizes: Tuple[int, int] = (32, 32), # [w, h] for stride 16 feats at 512 resolution\n **kwargs,\n ):\n \"\"\"Initialize RoPEAttention with rotary position encoding for enhanced positional awareness.\"\"\"\n super().__init__(*args, **kwargs)\n\n self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)\n freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])\n self.freqs_cis = freqs_cis\n self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories\n\n def forward(self, q: torch.Tensor, k: 
torch.Tensor, v: torch.Tensor, num_k_exclude_rope: int = 0) -> torch.Tensor:\n \"\"\"Apply rotary position encoding and compute attention between query, key, and value tensors.\"\"\"\n q = self.q_proj(q)\n k = self.k_proj(k)\n v = self.v_proj(v)\n\n # Separate into heads\n q = self._separate_heads(q, self.num_heads)\n k = self._separate_heads(k, self.num_heads)\n v = self._separate_heads(v, self.num_heads)\n\n # Apply rotary position encoding\n w = h = math.sqrt(q.shape[-2])\n self.freqs_cis = self.freqs_cis.to(q.device)\n if self.freqs_cis.shape[0] != q.shape[-2]:\n self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)\n if q.shape[-2] != k.shape[-2]:\n assert self.rope_k_repeat\n\n num_k_rope = k.size(-2) - num_k_exclude_rope\n q, k[:, :, :num_k_rope] = apply_rotary_enc(\n q,\n k[:, :, :num_k_rope],\n freqs_cis=self.freqs_cis,\n repeat_freqs_k=self.rope_k_repeat,\n )\n\n # Attention\n _, _, _, c_per_head = q.shape\n attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens\n attn = attn / math.sqrt(c_per_head)\n attn = torch.softmax(attn, dim=-1)\n\n # Get output\n out = attn @ v\n\n out = self._recombine_heads(out)\n out = self.out_proj(out)\n\n return out", "chunk_type": "class", "name": "RoPEAttention", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 405, "end_line": 485, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Implements rotary position encoding for attention mechanisms in transformer architectures.\n\nThis class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance\nthe positional awareness of the attention mechanism.\n\nAttributes:\n compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.\n freqs_cis (torch.Tensor): Precomputed frequency tensor for rotary encoding.\n rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.\n\nMethods:\n forward: Applies rotary position encoding and computes attention between query, key, and value tensors.\n\nExamples:\n >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))\n >>> q = torch.randn(1, 1024, 256)\n >>> k = torch.randn(1, 1024, 256)\n >>> v = torch.randn(1, 1024, 256)\n >>> output = rope_attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 1024, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "Attention" ], "chunk_id": "class_RoPEAttention_c89f6ba9" }, { "content": "def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:\n \"\"\"Apply pooling and optional normalization to a tensor, handling spatial dimension permutations.\"\"\"\n if pool is None:\n return x\n # (B, H, W, C) -> (B, C, H, W)\n x = x.permute(0, 3, 1, 2)\n x = pool(x)\n # (B, C, H', W') -> (B, H', W', C)\n x = x.permute(0, 2, 3, 1)\n if norm:\n x = norm(x)\n\n return x", "chunk_type": "function", "name": "do_pool", 
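The `do_pool` helper indexed above only re-orders dimensions around an arbitrary pooling module, so its effect is easiest to see with a concrete shape walk-through. The snippet below is an illustrative sketch, not part of the indexed `blocks.py`; it mirrors the same permute -> pool -> permute flow with a 2x2 max pool, and the tensor sizes are arbitrary example values.

```python
import torch
from torch import nn

# Channels-last (B, H, W, C) features are moved to channels-first for the
# pooling op, then moved back -- the same flow that do_pool wraps.
x = torch.randn(1, 56, 56, 256)               # (B, H, W, C), example shape
pool = nn.MaxPool2d(kernel_size=2, stride=2)  # stands in for the q_pool module
y = pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
print(y.shape)                                # torch.Size([1, 28, 28, 256])
```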
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 488, "end_line": 500, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Apply pooling and optional normalization to a tensor, handling spatial dimension permutations.", "parameters": [ "x: torch.Tensor", "pool: nn.Module", "norm: nn.Module" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 3, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition" ], "chunk_id": "function_do_pool_669ad8d8" }, { "content": "class MultiScaleAttention(nn.Module):\n \"\"\"\n Implements multiscale self-attention with optional query pooling for efficient feature extraction.\n\n This class provides a flexible implementation of multiscale attention, allowing for optional\n downsampling of query features through pooling. It's designed to enhance the model's ability to\n capture multiscale information in visual tasks.\n\n Attributes:\n dim (int): Input dimension of the feature map.\n dim_out (int): Output dimension of the attention module.\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for dot-product attention.\n q_pool (nn.Module | None): Optional pooling module for query features.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection.\n\n Methods:\n forward: Applies multiscale attention to the input tensor.\n\n Examples:\n >>> import torch\n >>> from torch import nn\n >>> x = torch.randn(1, 64, 64, 256)\n >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)\n >>> output = msa(x)\n >>> print(output.shape)\n torch.Size([1, 64, 64, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n dim_out: int,\n num_heads: int,\n q_pool: nn.Module = None,\n ):\n \"\"\"Initialize multiscale attention with optional query pooling for efficient feature extraction.\"\"\"\n super().__init__()\n\n self.dim = dim\n self.dim_out = dim_out\n\n self.num_heads = num_heads\n head_dim = dim_out // num_heads\n self.scale = head_dim**-0.5\n\n self.q_pool = q_pool\n self.qkv = nn.Linear(dim, dim_out * 3)\n self.proj = nn.Linear(dim_out, dim_out)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply multiscale attention with optional query pooling to extract multiscale features.\"\"\"\n B, H, W, _ = x.shape\n # qkv with shape (B, H * W, 3, nHead, C)\n qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)\n # q, k, v with shape (B, H * W, nheads, C)\n q, k, v = torch.unbind(qkv, 2)\n\n # Q pooling (for downsample at stage changes)\n if self.q_pool:\n q = do_pool(q.reshape(B, H, W, -1), self.q_pool)\n H, W = q.shape[1:3] # downsampled shape\n q = q.reshape(B, H * W, self.num_heads, -1)\n\n # Torch's SDPA expects [B, nheads, H*W, C] so we transpose\n x = F.scaled_dot_product_attention(\n q.transpose(1, 2),\n k.transpose(1, 2),\n v.transpose(1, 2),\n )\n # Transpose back\n x = x.transpose(1, 2)\n x = x.reshape(B, H, W, -1)\n\n x = self.proj(x)\n\n return x", "chunk_type": "class", "name": "MultiScaleAttention", "file_path": 
"ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 503, "end_line": 580, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Implements multiscale self-attention with optional query pooling for efficient feature extraction.\n\nThis class provides a flexible implementation of multiscale attention, allowing for optional\ndownsampling of query features through pooling. It's designed to enhance the model's ability to\ncapture multiscale information in visual tasks.\n\nAttributes:\n dim (int): Input dimension of the feature map.\n dim_out (int): Output dimension of the attention module.\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for dot-product attention.\n q_pool (nn.Module | None): Optional pooling module for query features.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection.\n\nMethods:\n forward: Applies multiscale attention to the input tensor.\n\nExamples:\n >>> import torch\n >>> from torch import nn\n >>> x = torch.randn(1, 64, 64, 256)\n >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)\n >>> output = msa(x)\n >>> print(output.shape)\n torch.Size([1, 64, 64, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_MultiScaleAttention_4d077d4a" }, { "content": "class MultiScaleBlock(nn.Module):\n \"\"\"\n A multiscale attention block with window partitioning and query pooling for efficient vision transformers.\n\n This class implements a multiscale attention mechanism with optional window partitioning and downsampling,\n designed for use in vision transformer architectures.\n\n Attributes:\n dim (int): Input dimension of the block.\n dim_out (int): Output dimension of the block.\n norm1 (nn.Module): First normalization layer.\n window_size (int): Size of the window for partitioning.\n pool (nn.Module | None): Pooling layer for query downsampling.\n q_stride (Tuple[int, int] | None): Stride for query pooling.\n attn (MultiScaleAttention): Multi-scale attention module.\n drop_path (nn.Module): Drop path layer for regularization.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLP): Multi-layer perceptron module.\n proj (nn.Linear | None): Projection layer for dimension mismatch.\n\n Methods:\n forward: Processes input tensor through the multiscale block.\n\n Examples:\n >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 28, 28, 512])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n dim_out: int,\n num_heads: int,\n mlp_ratio: float = 4.0,\n drop_path: float = 0.0,\n norm_layer: Union[nn.Module, str] = \"LayerNorm\",\n q_stride: Tuple[int, int] = None,\n act_layer: Type[nn.Module] = nn.GELU,\n window_size: int = 0,\n ):\n \"\"\"Initialize a multiscale attention block with window partitioning and optional query 
pooling.\"\"\"\n super().__init__()\n\n if isinstance(norm_layer, str):\n norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)\n\n self.dim = dim\n self.dim_out = dim_out\n self.norm1 = norm_layer(dim)\n\n self.window_size = window_size\n\n self.pool, self.q_stride = None, q_stride\n if self.q_stride:\n self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)\n\n self.attn = MultiScaleAttention(\n dim,\n dim_out,\n num_heads=num_heads,\n q_pool=self.pool,\n )\n self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n self.norm2 = norm_layer(dim_out)\n self.mlp = MLP(\n dim_out,\n int(dim_out * mlp_ratio),\n dim_out,\n num_layers=2,\n act=act_layer,\n )\n\n if dim != dim_out:\n self.proj = nn.Linear(dim, dim_out)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through multiscale attention and MLP, with optional windowing and downsampling.\"\"\"\n shortcut = x # B, H, W, C\n x = self.norm1(x)\n\n # Skip connection\n if self.dim != self.dim_out:\n shortcut = do_pool(self.proj(x), self.pool)\n\n # Window partition\n window_size = self.window_size\n if window_size > 0:\n H, W = x.shape[1], x.shape[2]\n x, pad_hw = window_partition(x, window_size)\n\n # Window Attention + Q Pooling (if stage change)\n x = self.attn(x)\n if self.q_stride:\n # Shapes have changed due to Q pooling\n window_size = self.window_size // self.q_stride[0]\n H, W = shortcut.shape[1:3]\n\n pad_h = (window_size - H % window_size) % window_size\n pad_w = (window_size - W % window_size) % window_size\n pad_hw = (H + pad_h, W + pad_w)\n\n # Reverse window partition\n if self.window_size > 0:\n x = window_unpartition(x, window_size, pad_hw, (H, W))\n\n x = shortcut + self.drop_path(x)\n # MLP\n x = x + self.drop_path(self.mlp(self.norm2(x)))\n return x", "chunk_type": "class", "name": "MultiScaleBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 583, "end_line": 695, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "A multiscale attention block with window partitioning and query pooling for efficient vision transformers.\n\nThis class implements a multiscale attention mechanism with optional window partitioning and downsampling,\ndesigned for use in vision transformer architectures.\n\nAttributes:\n dim (int): Input dimension of the block.\n dim_out (int): Output dimension of the block.\n norm1 (nn.Module): First normalization layer.\n window_size (int): Size of the window for partitioning.\n pool (nn.Module | None): Pooling layer for query downsampling.\n q_stride (Tuple[int, int] | None): Stride for query pooling.\n attn (MultiScaleAttention): Multi-scale attention module.\n drop_path (nn.Module): Drop path layer for regularization.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLP): Multi-layer perceptron module.\n proj (nn.Linear | None): Projection layer for dimension mismatch.\n\nMethods:\n forward: Processes input tensor through the multiscale block.\n\nExamples:\n >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 28, 28, 512])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", 
"ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_MultiScaleBlock_cea50200" }, { "content": "class PositionEmbeddingSine(nn.Module):\n \"\"\"\n A module for generating sinusoidal positional embeddings for 2D inputs like images.\n\n This class implements sinusoidal position encoding for 2D spatial positions, which can be used in\n transformer-based models for computer vision tasks.\n\n Attributes:\n num_pos_feats (int): Number of positional features (half of the embedding dimension).\n temperature (int): Temperature parameter for the sinusoidal functions.\n normalize (bool): Whether to normalize the positional embeddings.\n scale (float): Scaling factor for the embeddings when normalize is True.\n cache (dict): Cache for storing precomputed embeddings.\n\n Methods:\n _encode_xy: Encodes 2D positions using sine and cosine functions.\n encode_boxes: Encodes box coordinates and dimensions into positional embeddings.\n encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.\n forward: Generates sinusoidal position embeddings for 2D inputs.\n\n Examples:\n >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> embeddings = pos_emb(x)\n >>> print(embeddings.shape)\n torch.Size([1, 256, 224, 224])\n \"\"\"\n\n def __init__(\n self,\n num_pos_feats: int,\n temperature: int = 10000,\n normalize: bool = True,\n scale: Optional[float] = None,\n ):\n \"\"\"Initialize sinusoidal position embeddings for 2D image inputs.\"\"\"\n super().__init__()\n assert num_pos_feats % 2 == 0, \"Expecting even model width\"\n self.num_pos_feats = num_pos_feats // 2\n self.temperature = temperature\n self.normalize = normalize\n if scale is not None and not normalize:\n raise ValueError(\"normalize should be True if scale is passed\")\n if scale is None:\n scale = 2 * math.pi\n self.scale = scale\n\n self.cache = {}\n\n def _encode_xy(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Encode 2D positions using sine/cosine functions for transformer positional embeddings.\"\"\"\n assert len(x) == len(y) and x.ndim == y.ndim == 1\n x_embed = x * self.scale\n y_embed = y * self.scale\n\n dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)\n dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n pos_x = x_embed[:, None] / dim_t\n pos_y = y_embed[:, None] / dim_t\n pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)\n pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)\n return pos_x, pos_y\n\n @torch.no_grad()\n def encode_boxes(self, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor) -> torch.Tensor:\n \"\"\"Encode box coordinates and dimensions into positional embeddings for detection.\"\"\"\n pos_x, pos_y = self._encode_xy(x, y)\n return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)\n\n encode = encode_boxes # Backwards compatibility\n\n @torch.no_grad()\n def encode_points(self, x: torch.Tensor, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n \"\"\"Encode 2D points with sinusoidal embeddings and append labels.\"\"\"\n (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape\n assert 
bx == by and nx == ny and bx == bl and nx == nl\n pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())\n pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)\n return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)\n\n @torch.no_grad()\n def forward(self, x: torch.Tensor) -> Tensor:\n \"\"\"Generate sinusoidal position embeddings for 2D inputs like images.\"\"\"\n cache_key = (x.shape[-2], x.shape[-1])\n if cache_key in self.cache:\n return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)\n y_embed = (\n torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)\n .view(1, -1, 1)\n .repeat(x.shape[0], 1, x.shape[-1])\n )\n x_embed = (\n torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)\n .view(1, 1, -1)\n .repeat(x.shape[0], x.shape[-2], 1)\n )\n\n if self.normalize:\n eps = 1e-6\n y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale\n x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale\n\n dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)\n dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n pos_x = x_embed[:, :, :, None] / dim_t\n pos_y = y_embed[:, :, :, None] / dim_t\n pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)\n pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)\n pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)\n self.cache[cache_key] = pos[0]\n return pos", "chunk_type": "class", "name": "PositionEmbeddingSine", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 698, "end_line": 810, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "A module for generating sinusoidal positional embeddings for 2D inputs like images.\n\nThis class implements sinusoidal position encoding for 2D spatial positions, which can be used in\ntransformer-based models for computer vision tasks.\n\nAttributes:\n num_pos_feats (int): Number of positional features (half of the embedding dimension).\n temperature (int): Temperature parameter for the sinusoidal functions.\n normalize (bool): Whether to normalize the positional embeddings.\n scale (float): Scaling factor for the embeddings when normalize is True.\n cache (dict): Cache for storing precomputed embeddings.\n\nMethods:\n _encode_xy: Encodes 2D positions using sine and cosine functions.\n encode_boxes: Encodes box coordinates and dimensions into positional embeddings.\n encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.\n forward: Generates sinusoidal position embeddings for 2D inputs.\n\nExamples:\n >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> embeddings = pos_emb(x)\n >>> print(embeddings.shape)\n torch.Size([1, 256, 224, 224])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": 
"class_PositionEmbeddingSine_cb0e1827" }, { "content": "class PositionEmbeddingRandom(nn.Module):\n \"\"\"\n Positional encoding using random spatial frequencies.\n\n This class generates positional embeddings for input coordinates using random spatial frequencies. It is\n particularly useful for transformer-based models that require position information.\n\n Attributes:\n positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.\n\n Methods:\n _pe_encoding: Positionally encodes points that are normalized to [0,1].\n forward: Generates positional encoding for a grid of the specified size.\n forward_with_coords: Positionally encodes points that are not normalized to [0,1].\n\n Examples:\n >>> pe = PositionEmbeddingRandom(num_pos_feats=64)\n >>> size = (32, 32)\n >>> encoding = pe(size)\n >>> print(encoding.shape)\n torch.Size([128, 32, 32])\n \"\"\"\n\n def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:\n \"\"\"Initialize random spatial frequency position embedding for transformers.\"\"\"\n super().__init__()\n if scale is None or scale <= 0.0:\n scale = 1.0\n self.register_buffer(\"positional_encoding_gaussian_matrix\", scale * torch.randn((2, num_pos_feats)))\n\n # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'\n torch.use_deterministic_algorithms(False)\n torch.backends.cudnn.deterministic = False\n\n def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:\n \"\"\"Encode normalized [0,1] coordinates using random spatial frequencies.\"\"\"\n # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape\n coords = 2 * coords - 1\n coords = coords @ self.positional_encoding_gaussian_matrix\n coords = 2 * np.pi * coords\n # Outputs d_1 x ... x d_n x C shape\n return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)\n\n def forward(self, size: Tuple[int, int]) -> torch.Tensor:\n \"\"\"Generate positional encoding for a grid using random spatial frequencies.\"\"\"\n h, w = size\n device: Any = self.positional_encoding_gaussian_matrix.device\n grid = torch.ones((h, w), device=device, dtype=torch.float32)\n y_embed = grid.cumsum(dim=0) - 0.5\n x_embed = grid.cumsum(dim=1) - 0.5\n y_embed = y_embed / h\n x_embed = x_embed / w\n\n pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))\n return pe.permute(2, 0, 1) # C x H x W\n\n def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:\n \"\"\"Positionally encode input coordinates, normalizing them to [0,1] based on the given image size.\"\"\"\n coords = coords_input.clone()\n coords[:, :, 0] = coords[:, :, 0] / image_size[1]\n coords[:, :, 1] = coords[:, :, 1] / image_size[0]\n return self._pe_encoding(coords.to(torch.float)) # B x N x C", "chunk_type": "class", "name": "PositionEmbeddingRandom", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 813, "end_line": 874, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "Positional encoding using random spatial frequencies.\n\nThis class generates positional embeddings for input coordinates using random spatial frequencies. 
It is\nparticularly useful for transformer-based models that require position information.\n\nAttributes:\n positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.\n\nMethods:\n _pe_encoding: Positionally encodes points that are normalized to [0,1].\n forward: Generates positional encoding for a grid of the specified size.\n forward_with_coords: Positionally encodes points that are not normalized to [0,1].\n\nExamples:\n >>> pe = PositionEmbeddingRandom(num_pos_feats=64)\n >>> size = (32, 32)\n >>> encoding = pe(size)\n >>> print(encoding.shape)\n torch.Size([128, 32, 32])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_PositionEmbeddingRandom_b8cb3304" }, { "content": "class Block(nn.Module):\n \"\"\"\n Transformer block with support for window attention and residual propagation.\n\n This class implements a transformer block that can use either global or windowed self-attention,\n followed by a feed-forward network. It supports relative positional embeddings and is designed\n for use in vision transformer architectures.\n\n Attributes:\n norm1 (nn.Module): First normalization layer.\n attn (REAttention): Self-attention layer with optional relative positional encoding.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLPBlock): Multi-layer perceptron block.\n window_size (int): Size of attention window. If 0, global attention is used.\n\n Methods:\n forward: Processes input through the transformer block.\n\n Examples:\n >>> import torch\n >>> block = Block(dim=256, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 56, 56, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n num_heads: int,\n mlp_ratio: float = 4.0,\n qkv_bias: bool = True,\n norm_layer: Type[nn.Module] = nn.LayerNorm,\n act_layer: Type[nn.Module] = nn.GELU,\n use_rel_pos: bool = False,\n rel_pos_zero_init: bool = True,\n window_size: int = 0,\n input_size: Optional[Tuple[int, int]] = None,\n ) -> None:\n \"\"\"\n Initialize a transformer block with optional window attention and relative positional embeddings.\n\n This constructor sets up a transformer block that can use either global or windowed self-attention,\n followed by a feed-forward network. 
It supports relative positional embeddings and is designed\n for use in vision transformer architectures.\n\n Args:\n dim (int): Number of input channels.\n num_heads (int): Number of attention heads in the self-attention layer.\n mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.\n qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.\n norm_layer (Type[nn.Module]): Type of normalization layer to use.\n act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.\n use_rel_pos (bool): If True, uses relative positional embeddings in attention.\n rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.\n window_size (int): Size of attention window. If 0, uses global attention.\n input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.\n\n Examples:\n >>> block = Block(dim=256, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 56, 56, 256])\n \"\"\"\n super().__init__()\n self.norm1 = norm_layer(dim)\n self.attn = REAttention(\n dim,\n num_heads=num_heads,\n qkv_bias=qkv_bias,\n use_rel_pos=use_rel_pos,\n rel_pos_zero_init=rel_pos_zero_init,\n input_size=input_size if window_size == 0 else (window_size, window_size),\n )\n\n self.norm2 = norm_layer(dim)\n self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)\n\n self.window_size = window_size\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through transformer block with optional windowed self-attention and residual connection.\"\"\"\n shortcut = x\n x = self.norm1(x)\n # Window partition\n if self.window_size > 0:\n H, W = x.shape[1], x.shape[2]\n x, pad_hw = window_partition(x, self.window_size)\n\n x = self.attn(x)\n # Reverse window partition\n if self.window_size > 0:\n x = window_unpartition(x, self.window_size, pad_hw, (H, W))\n\n x = shortcut + x\n return x + self.mlp(self.norm2(x))", "chunk_type": "class", "name": "Block", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 877, "end_line": 974, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Transformer block with support for window attention and residual propagation.\n\nThis class implements a transformer block that can use either global or windowed self-attention,\nfollowed by a feed-forward network. It supports relative positional embeddings and is designed\nfor use in vision transformer architectures.\n\nAttributes:\n norm1 (nn.Module): First normalization layer.\n attn (REAttention): Self-attention layer with optional relative positional encoding.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLPBlock): Multi-layer perceptron block.\n window_size (int): Size of attention window. 
If 0, global attention is used.\n\nMethods:\n forward: Processes input through the transformer block.\n\nExamples:\n >>> import torch\n >>> block = Block(dim=256, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 56, 56, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_Block_3dc742de" }, { "content": "class REAttention(nn.Module):\n \"\"\"\n Relative Position Attention module for efficient self-attention in transformer architectures.\n\n This class implements a multi-head attention mechanism with relative positional embeddings, designed\n for use in vision transformer models. It supports optional query pooling and window partitioning\n for efficient processing of large inputs.\n\n Attributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention computation.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection layer.\n use_rel_pos (bool): Whether to use relative positional embeddings.\n rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.\n rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.\n\n Methods:\n forward: Applies multi-head attention with optional relative positional encoding to input tensor.\n\n Examples:\n >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))\n >>> x = torch.randn(1, 32, 32, 256)\n >>> output = attention(x)\n >>> print(output.shape)\n torch.Size([1, 32, 32, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n num_heads: int = 8,\n qkv_bias: bool = True,\n use_rel_pos: bool = False,\n rel_pos_zero_init: bool = True,\n input_size: Optional[Tuple[int, int]] = None,\n ) -> None:\n \"\"\"\n Initialize a Relative Position Attention module for transformer-based architectures.\n\n This module implements multi-head attention with optional relative positional encodings, designed\n specifically for vision tasks in transformer models.\n\n Args:\n dim (int): Number of input channels.\n num_heads (int): Number of attention heads.\n qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.\n use_rel_pos (bool): If True, uses relative positional encodings.\n rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.\n input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.\n Required if use_rel_pos is True.\n\n Examples:\n >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))\n >>> x = torch.randn(1, 32, 32, 256)\n >>> output = attention(x)\n >>> print(output.shape)\n torch.Size([1, 32, 32, 256])\n \"\"\"\n super().__init__()\n self.num_heads = num_heads\n head_dim = dim // num_heads\n self.scale = head_dim**-0.5\n\n self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n self.proj = nn.Linear(dim, dim)\n\n 
self.use_rel_pos = use_rel_pos\n if self.use_rel_pos:\n assert input_size is not None, \"Input size must be provided if using relative positional encoding.\"\n # Initialize relative positional embeddings\n self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))\n self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply multi-head attention with optional relative positional encoding to input tensor.\"\"\"\n B, H, W, _ = x.shape\n # qkv with shape (3, B, nHead, H * W, C)\n qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n # q, k, v with shape (B * nHead, H * W, C)\n q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)\n\n attn = (q * self.scale) @ k.transpose(-2, -1)\n\n if self.use_rel_pos:\n attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))\n\n attn = attn.softmax(dim=-1)\n x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)\n return self.proj(x)", "chunk_type": "class", "name": "REAttention", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 977, "end_line": 1066, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Relative Position Attention module for efficient self-attention in transformer architectures.\n\nThis class implements a multi-head attention mechanism with relative positional embeddings, designed\nfor use in vision transformer models. It supports optional query pooling and window partitioning\nfor efficient processing of large inputs.\n\nAttributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention computation.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection layer.\n use_rel_pos (bool): Whether to use relative positional embeddings.\n rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.\n rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.\n\nMethods:\n forward: Applies multi-head attention with optional relative positional encoding to input tensor.\n\nExamples:\n >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))\n >>> x = torch.randn(1, 32, 32, 256)\n >>> output = attention(x)\n >>> print(output.shape)\n torch.Size([1, 32, 32, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_REAttention_f78613db" }, { "content": "class PatchEmbed(nn.Module):\n \"\"\"\n Image to Patch Embedding module for vision transformer architectures.\n\n This module converts an input image into a sequence of patch embeddings using a convolutional layer.\n It is commonly used as the first layer in vision transformer architectures to transform image data\n into a suitable format for subsequent transformer blocks.\n\n Attributes:\n proj (nn.Conv2d): 
Convolutional layer for projecting image patches to embeddings.\n\n Methods:\n forward: Applies patch embedding to the input tensor.\n\n Examples:\n >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 768, 14, 14])\n \"\"\"\n\n def __init__(\n self,\n kernel_size: Tuple[int, int] = (16, 16),\n stride: Tuple[int, int] = (16, 16),\n padding: Tuple[int, int] = (0, 0),\n in_chans: int = 3,\n embed_dim: int = 768,\n ) -> None:\n \"\"\"\n Initialize the PatchEmbed module for converting image patches to embeddings.\n\n This module is typically used as the first layer in vision transformer architectures to transform\n image data into a suitable format for subsequent transformer blocks.\n\n Args:\n kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.\n stride (Tuple[int, int]): Stride of the convolutional operation.\n padding (Tuple[int, int]): Padding applied to the input before convolution.\n in_chans (int): Number of input image channels.\n embed_dim (int): Dimensionality of the output patch embeddings.\n\n Examples:\n >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 768, 14, 14])\n \"\"\"\n super().__init__()\n\n self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Compute patch embedding by applying convolution and transposing resulting tensor.\"\"\"\n return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C", "chunk_type": "class", "name": "PatchEmbed", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py", "start_line": 1069, "end_line": 1125, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Image to Patch Embedding module for vision transformer architectures.\n\nThis module converts an input image into a sequence of patch embeddings using a convolutional layer.\nIt is commonly used as the first layer in vision transformer architectures to transform image data\ninto a suitable format for subsequent transformer blocks.\n\nAttributes:\n proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.\n\nMethods:\n forward: Applies patch embedding to the input tensor.\n\nExamples:\n >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 768, 14, 14])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "functools.partial", "typing.Any", "typing.Optional", "typing.Tuple", "typing.Type", "typing.Union", "numpy", "torch", "torch.nn.functional", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.nn.modules.MLPBlock", "transformer.Attention", "transformer.TwoWayAttentionBlock", "transformer.TwoWayTransformer", "utils.add_decomposed_rel_pos", "utils.apply_rotary_enc", "utils.compute_axial_cis", "utils.window_partition", "utils.window_unpartition", "nn.Module" ], "chunk_id": "class_PatchEmbed_102cc953" }, { "content": "from typing import List, Optional, Tuple, Type", "chunk_type": "import", "name": "List, Optional, Tuple, 
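A standalone sketch of the PatchEmbed idea described above, using a plain strided Conv2d rather than the library class, to make the shape arithmetic explicit.

```python
import torch
import torch.nn as nn

# Minimal standalone sketch (not the library class): a 16x16 strided Conv2d maps
# each pixel patch to an embedding vector, then the tensor is made channels-last.
proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
x = torch.randn(1, 3, 224, 224)
patches = proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C, as in forward() above
print(patches.shape)                   # torch.Size([1, 14, 14, 768]) since 224 / 16 = 14
```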
Type", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple, Type_d508b8c2" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_5dca3f97" }, { "content": "from torch import nn", "chunk_type": "import", "name": "nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_nn_7328c221" }, { "content": "from ultralytics.nn.modules import MLP, LayerNorm2d", "chunk_type": "import", "name": "MLP, LayerNorm2d", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MLP, LayerNorm2d_c51d52d2" }, { "content": "class MaskDecoder(nn.Module):\n \"\"\"\n Decoder module for generating masks and their associated quality scores using a transformer architecture.\n\n This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and\n generate mask predictions along with their quality scores.\n\n Attributes:\n transformer_dim (int): Channel dimension for the transformer module.\n transformer (nn.Module): Transformer module used for mask prediction.\n num_multimask_outputs (int): Number of masks to predict for disambiguating masks.\n iou_token (nn.Embedding): Embedding for the IoU token.\n num_mask_tokens (int): Number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for the mask tokens.\n output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.\n output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.\n iou_prediction_head (nn.Module): MLP for predicting mask quality.\n\n Methods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Internal method for mask prediction.\n\n Examples:\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)\n >>> masks, iou_pred = decoder(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True\n ... 
)\n >>> print(f\"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}\")\n \"\"\"\n\n def __init__(\n self,\n transformer_dim: int,\n transformer: nn.Module,\n num_multimask_outputs: int = 3,\n activation: Type[nn.Module] = nn.GELU,\n iou_head_depth: int = 3,\n iou_head_hidden_dim: int = 256,\n ) -> None:\n \"\"\"\n Initialize the MaskDecoder module for generating masks and their associated quality scores.\n\n Args:\n transformer_dim (int): Channel dimension for the transformer module.\n transformer (nn.Module): Transformer module used for mask prediction.\n num_multimask_outputs (int): Number of masks to predict for disambiguating masks.\n activation (Type[nn.Module]): Type of activation to use when upscaling masks.\n iou_head_depth (int): Depth of the MLP used to predict mask quality.\n iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.\n\n Examples:\n >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)\n >>> print(decoder)\n \"\"\"\n super().__init__()\n self.transformer_dim = transformer_dim\n self.transformer = transformer\n\n self.num_multimask_outputs = num_multimask_outputs\n\n self.iou_token = nn.Embedding(1, transformer_dim)\n self.num_mask_tokens = num_multimask_outputs + 1\n self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)\n\n self.output_upscaling = nn.Sequential(\n nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),\n LayerNorm2d(transformer_dim // 4),\n activation(),\n nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),\n activation(),\n )\n self.output_hypernetworks_mlps = nn.ModuleList(\n [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]\n )\n\n self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)\n\n def forward(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n multimask_output: bool,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Predict masks given image and prompt embeddings.\n\n Args:\n image_embeddings (torch.Tensor): Embeddings from the image encoder.\n image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.\n sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.\n dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.\n multimask_output (bool): Whether to return multiple masks or a single mask.\n\n Returns:\n masks (torch.Tensor): Batched predicted masks.\n iou_pred (torch.Tensor): Batched predictions of mask quality.\n\n Examples:\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)\n >>> image_emb = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_emb = torch.rand(1, 2, 256)\n >>> dense_emb = torch.rand(1, 256, 64, 64)\n >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)\n >>> print(f\"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}\")\n \"\"\"\n masks, iou_pred = self.predict_masks(\n image_embeddings=image_embeddings,\n image_pe=image_pe,\n sparse_prompt_embeddings=sparse_prompt_embeddings,\n dense_prompt_embeddings=dense_prompt_embeddings,\n )\n\n # Select the correct mask or masks for 
output\n mask_slice = slice(1, None) if multimask_output else slice(0, 1)\n masks = masks[:, mask_slice, :, :]\n iou_pred = iou_pred[:, mask_slice]\n\n return masks, iou_pred\n\n def predict_masks(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Predict masks and quality scores using image and prompt embeddings via transformer architecture.\"\"\"\n # Concatenate output tokens\n output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)\n output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)\n tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)\n\n # Expand per-image data in batch direction to be per-mask\n src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)\n src = src + dense_prompt_embeddings\n pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)\n b, c, h, w = src.shape\n\n # Run the transformer\n hs, src = self.transformer(src, pos_src, tokens)\n iou_token_out = hs[:, 0, :]\n mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]\n\n # Upscale mask embeddings and predict masks using the mask tokens\n src = src.transpose(1, 2).view(b, c, h, w)\n upscaled_embedding = self.output_upscaling(src)\n hyper_in_list: List[torch.Tensor] = [\n self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)\n ]\n hyper_in = torch.stack(hyper_in_list, dim=1)\n b, c, h, w = upscaled_embedding.shape\n masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)\n\n # Generate mask quality predictions\n iou_pred = self.iou_prediction_head(iou_token_out)\n\n return masks, iou_pred", "chunk_type": "class", "name": "MaskDecoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py", "start_line": 11, "end_line": 171, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Decoder module for generating masks and their associated quality scores using a transformer architecture.\n\nThis class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and\ngenerate mask predictions along with their quality scores.\n\nAttributes:\n transformer_dim (int): Channel dimension for the transformer module.\n transformer (nn.Module): Transformer module used for mask prediction.\n num_multimask_outputs (int): Number of masks to predict for disambiguating masks.\n iou_token (nn.Embedding): Embedding for the IoU token.\n num_mask_tokens (int): Number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for the mask tokens.\n output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.\n output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.\n iou_prediction_head (nn.Module): MLP for predicting mask quality.\n\nMethods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Internal method for mask prediction.\n\nExamples:\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)\n >>> masks, iou_pred = decoder(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True\n ... 
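A standalone sketch of the hypernetwork readout used in predict_masks above: each mask token is mapped to a per-mask filter, which is then dotted with the upscaled embedding at every spatial location. Shapes here are illustrative.

```python
import torch

# Sketch of: masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
b, n_masks, c, h, w = 1, 4, 32, 256, 256
hyper_in = torch.randn(b, n_masks, c)   # one c-dim filter per mask token
upscaled = torch.randn(b, c, h, w)      # upscaled mask embedding
masks = (hyper_in @ upscaled.view(b, c, h * w)).view(b, n_masks, h, w)
print(masks.shape)                      # torch.Size([1, 4, 256, 256])
```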
)\n >>> print(f\"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}\")", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "nn.Module" ], "chunk_id": "class_MaskDecoder_43de16dc" }, { "content": "class SAM2MaskDecoder(nn.Module):\n \"\"\"\n Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.\n\n This class extends the functionality of the MaskDecoder, incorporating additional features such as\n high-resolution feature processing, dynamic multimask output, and object score prediction.\n\n Attributes:\n transformer_dim (int): Channel dimension of the transformer.\n transformer (nn.Module): Transformer used to predict masks.\n num_multimask_outputs (int): Number of masks to predict when disambiguating masks.\n iou_token (nn.Embedding): Embedding for IOU token.\n num_mask_tokens (int): Total number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for mask tokens.\n pred_obj_scores (bool): Whether to predict object scores.\n obj_score_token (nn.Embedding): Embedding for object score token.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.\n output_upscaling (nn.Sequential): Upscaling layers for output.\n use_high_res_features (bool): Whether to use high-resolution features.\n conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).\n conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).\n output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.\n iou_prediction_head (MLP): MLP for IOU prediction.\n pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.\n dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.\n dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.\n dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.\n\n Methods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Predict instance segmentation masks from image and prompt embeddings.\n _get_stability_scores: Compute mask stability scores based on IoU between thresholds.\n _dynamic_multimask_via_stability: Dynamically select the most stable mask output.\n\n Examples:\n >>> image_embeddings = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)\n >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)\n >>> decoder = SAM2MaskDecoder(256, transformer)\n >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False\n ... 
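A hedged end-to-end sketch wiring MaskDecoder to the TwoWayTransformer listed in this chunk's dependencies. Import paths and transformer hyperparameters are assumptions; the tensor shapes follow the docstring above.

```python
import torch
from ultralytics.models.sam.modules.decoders import MaskDecoder
from ultralytics.models.sam.modules.transformer import TwoWayTransformer  # assumed sibling module

decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048),
)
image_emb = torch.rand(1, 256, 64, 64)   # image encoder output
image_pe = torch.rand(1, 256, 64, 64)    # dense positional encoding
sparse = torch.rand(1, 2, 256)           # e.g. two point prompt embeddings
dense = torch.rand(1, 256, 64, 64)       # mask prompt embedding
masks, iou_pred = decoder(image_emb, image_pe, sparse, dense, multimask_output=True)
print(masks.shape, iou_pred.shape)       # (1, 3, 256, 256) and (1, 3): output upscaled 4x, 3 candidate masks
```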
)\n \"\"\"\n\n def __init__(\n self,\n transformer_dim: int,\n transformer: nn.Module,\n num_multimask_outputs: int = 3,\n activation: Type[nn.Module] = nn.GELU,\n iou_head_depth: int = 3,\n iou_head_hidden_dim: int = 256,\n use_high_res_features: bool = False,\n iou_prediction_use_sigmoid=False,\n dynamic_multimask_via_stability=False,\n dynamic_multimask_stability_delta=0.05,\n dynamic_multimask_stability_thresh=0.98,\n pred_obj_scores: bool = False,\n pred_obj_scores_mlp: bool = False,\n use_multimask_token_for_obj_ptr: bool = False,\n ) -> None:\n \"\"\"\n Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.\n\n This decoder extends the functionality of MaskDecoder, incorporating additional features such as\n high-resolution feature processing, dynamic multimask output, and object score prediction.\n\n Args:\n transformer_dim (int): Channel dimension of the transformer.\n transformer (nn.Module): Transformer used to predict masks.\n num_multimask_outputs (int): Number of masks to predict when disambiguating masks.\n activation (Type[nn.Module]): Type of activation to use when upscaling masks.\n iou_head_depth (int): Depth of the MLP used to predict mask quality.\n iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.\n use_high_res_features (bool): Whether to use high-resolution features.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.\n dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.\n dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.\n dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.\n pred_obj_scores (bool): Whether to predict object scores.\n pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.\n\n Examples:\n >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)\n >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)\n >>> print(decoder)\n \"\"\"\n super().__init__()\n self.transformer_dim = transformer_dim\n self.transformer = transformer\n\n self.num_multimask_outputs = num_multimask_outputs\n\n self.iou_token = nn.Embedding(1, transformer_dim)\n self.num_mask_tokens = num_multimask_outputs + 1\n self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)\n\n self.pred_obj_scores = pred_obj_scores\n if self.pred_obj_scores:\n self.obj_score_token = nn.Embedding(1, transformer_dim)\n self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr\n\n self.output_upscaling = nn.Sequential(\n nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),\n LayerNorm2d(transformer_dim // 4),\n activation(),\n nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),\n activation(),\n )\n self.use_high_res_features = use_high_res_features\n if use_high_res_features:\n self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)\n self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)\n\n self.output_hypernetworks_mlps = nn.ModuleList(\n [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]\n )\n\n self.iou_prediction_head = MLP(\n transformer_dim,\n iou_head_hidden_dim,\n self.num_mask_tokens,\n iou_head_depth,\n 
sigmoid=iou_prediction_use_sigmoid,\n )\n if self.pred_obj_scores:\n self.pred_obj_score_head = nn.Linear(transformer_dim, 1)\n if pred_obj_scores_mlp:\n self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)\n\n # When outputting a single mask, optionally we can dynamically fall back to the best\n # multimask output token if the single mask output token gives low stability scores.\n self.dynamic_multimask_via_stability = dynamic_multimask_via_stability\n self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta\n self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh\n\n def forward(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n multimask_output: bool,\n repeat_image: bool,\n high_res_features: Optional[List[torch.Tensor]] = None,\n ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Predict masks given image and prompt embeddings.\n\n Args:\n image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).\n image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).\n sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).\n dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).\n multimask_output (bool): Whether to return multiple masks or a single mask.\n repeat_image (bool): Flag to repeat the image embeddings.\n high_res_features (List[torch.Tensor] | None, optional): Optional high-resolution features.\n\n Returns:\n masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).\n iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).\n sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).\n object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).\n\n Examples:\n >>> image_embeddings = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)\n >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)\n >>> decoder = SAM2MaskDecoder(256, transformer)\n >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False\n ... )\n \"\"\"\n masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(\n image_embeddings=image_embeddings,\n image_pe=image_pe,\n sparse_prompt_embeddings=sparse_prompt_embeddings,\n dense_prompt_embeddings=dense_prompt_embeddings,\n repeat_image=repeat_image,\n high_res_features=high_res_features,\n )\n\n # Select the correct mask or masks for output\n if multimask_output:\n masks = masks[:, 1:, :, :]\n iou_pred = iou_pred[:, 1:]\n elif self.dynamic_multimask_via_stability and not self.training:\n masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)\n else:\n masks = masks[:, 0:1, :, :]\n iou_pred = iou_pred[:, 0:1]\n\n if multimask_output and self.use_multimask_token_for_obj_ptr:\n sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape\n else:\n # Take the mask output token. Here we *always* use the token for single mask output.\n # At test time, even if we track after 1-click (and using multimask_output=True),\n # we still take the single mask token here. 
The rationale is that we always track\n # after multiple clicks during training, so the past tokens seen during training\n # are always the single mask token (and we'll let it be the object-memory token).\n sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape\n\n return masks, iou_pred, sam_tokens_out, object_score_logits\n\n def predict_masks(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n repeat_image: bool,\n high_res_features: Optional[List[torch.Tensor]] = None,\n ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"Predict instance segmentation masks from image and prompt embeddings using a transformer.\"\"\"\n # Concatenate output tokens\n s = 0\n if self.pred_obj_scores:\n output_tokens = torch.cat(\n [\n self.obj_score_token.weight,\n self.iou_token.weight,\n self.mask_tokens.weight,\n ],\n dim=0,\n )\n s = 1\n else:\n output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)\n output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)\n tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)\n\n # Expand per-image data in batch direction to be per-mask\n if repeat_image:\n src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)\n else:\n assert image_embeddings.shape[0] == tokens.shape[0]\n src = image_embeddings\n src = src + dense_prompt_embeddings\n assert image_pe.size(0) == 1, \"image_pe should have size 1 in batch dim (from `get_dense_pe()`)\"\n pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)\n b, c, h, w = src.shape\n\n # Run the transformer\n hs, src = self.transformer(src, pos_src, tokens)\n iou_token_out = hs[:, s, :]\n mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]\n\n # Upscale mask embeddings and predict masks using the mask tokens\n src = src.transpose(1, 2).view(b, c, h, w)\n if not self.use_high_res_features:\n upscaled_embedding = self.output_upscaling(src)\n else:\n dc1, ln1, act1, dc2, act2 = self.output_upscaling\n feat_s0, feat_s1 = high_res_features\n upscaled_embedding = act1(ln1(dc1(src) + feat_s1))\n upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)\n\n hyper_in_list: List[torch.Tensor] = [\n self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)\n ]\n hyper_in = torch.stack(hyper_in_list, dim=1)\n b, c, h, w = upscaled_embedding.shape\n masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)\n\n # Generate mask quality predictions\n iou_pred = self.iou_prediction_head(iou_token_out)\n if self.pred_obj_scores:\n assert s == 1\n object_score_logits = self.pred_obj_score_head(hs[:, 0, :])\n else:\n # Obj scores logits - default to 10.0, i.e. 
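A small standalone sketch of the output-token layout assembled in predict_masks above when pred_obj_scores=True. The constant fill values are illustrative markers, not real embeddings; the point is the [obj_score | iou | mask tokens | sparse prompts] ordering and the s = 1 offset.

```python
import torch

obj_tok = torch.zeros(1, 256)                            # marker for the object-score token
iou_tok = torch.ones(1, 256)                             # marker for the IoU token
mask_toks = torch.full((4, 256), 2.0)                    # num_mask_tokens = 3 + 1
prompts = torch.full((1, 2, 256), 3.0)                   # two sparse prompt embeddings
output_tokens = torch.cat([obj_tok, iou_tok, mask_toks], dim=0).unsqueeze(0)  # (1, 6, 256)
tokens = torch.cat((output_tokens, prompts), dim=1)      # (1, 8, 256)
s = 1                                                    # offset introduced by the obj-score token
print(tokens[0, s, 0].item())                            # 1.0 -> the IoU token slot
print(tokens[0, s + 1 : s + 1 + 4, 0].tolist())          # [2.0, 2.0, 2.0, 2.0] -> the mask token slots
```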
assuming the object is present, sigmoid(10)=1\n object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)\n\n return masks, iou_pred, mask_tokens_out, object_score_logits\n\n def _get_stability_scores(self, mask_logits):\n \"\"\"Compute mask stability scores based on IoU between upper and lower thresholds.\"\"\"\n mask_logits = mask_logits.flatten(-2)\n stability_delta = self.dynamic_multimask_stability_delta\n area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()\n area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()\n return torch.where(area_u > 0, area_i / area_u, 1.0)\n\n def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):\n \"\"\"\n Dynamically select the most stable mask output based on stability scores and IoU predictions.\n\n This method is used when outputting a single mask. If the stability score from the current single-mask\n output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs\n (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask\n for both clicking and tracking scenarios.\n\n Args:\n all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is\n batch size, N is number of masks (typically 4), and H, W are mask dimensions.\n all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).\n\n Returns:\n mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).\n iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).\n\n Examples:\n >>> decoder = SAM2MaskDecoder(...)\n >>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each\n >>> all_iou_scores = torch.rand(2, 4)\n >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)\n >>> print(mask_logits.shape, iou_scores.shape)\n torch.Size([2, 1, 256, 256]) torch.Size([2, 1])\n \"\"\"\n # The best mask from multimask output tokens (1~3)\n multimask_logits = all_mask_logits[:, 1:, :, :]\n multimask_iou_scores = all_iou_scores[:, 1:]\n best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)\n batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)\n best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]\n best_multimask_logits = best_multimask_logits.unsqueeze(1)\n best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]\n best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)\n\n # The mask from singlemask output token 0 and its stability score\n singlemask_logits = all_mask_logits[:, 0:1, :, :]\n singlemask_iou_scores = all_iou_scores[:, 0:1]\n stability_scores = self._get_stability_scores(singlemask_logits)\n is_stable = stability_scores >= self.dynamic_multimask_stability_thresh\n\n # Dynamically fall back to best multimask output upon low stability scores.\n mask_logits_out = torch.where(\n is_stable[..., None, None].expand_as(singlemask_logits),\n singlemask_logits,\n best_multimask_logits,\n )\n iou_scores_out = torch.where(\n is_stable.expand_as(singlemask_iou_scores),\n singlemask_iou_scores,\n best_multimask_iou_scores,\n )\n return mask_logits_out, iou_scores_out", "chunk_type": "class", "name": "SAM2MaskDecoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py", "start_line": 174, "end_line": 513, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Transformer-based decoder for predicting instance 
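A standalone sketch of the stability score computed by _get_stability_scores above: the IoU between the mask thresholded at +delta and at -delta. Values near 1.0 mean the mask barely changes with the threshold, i.e. it is stable enough to keep the single-mask output.

```python
import torch

delta = 0.05
mask_logits = torch.randn(2, 1, 256, 256).flatten(-2)     # (B, N, H*W)
area_i = torch.sum(mask_logits > delta, dim=-1).float()   # area at the tighter threshold
area_u = torch.sum(mask_logits > -delta, dim=-1).float()  # area at the looser threshold
stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))
print(stability.shape)                                    # torch.Size([2, 1])
```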
segmentation masks from image and prompt embeddings.\n\nThis class extends the functionality of the MaskDecoder, incorporating additional features such as\nhigh-resolution feature processing, dynamic multimask output, and object score prediction.\n\nAttributes:\n transformer_dim (int): Channel dimension of the transformer.\n transformer (nn.Module): Transformer used to predict masks.\n num_multimask_outputs (int): Number of masks to predict when disambiguating masks.\n iou_token (nn.Embedding): Embedding for IOU token.\n num_mask_tokens (int): Total number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for mask tokens.\n pred_obj_scores (bool): Whether to predict object scores.\n obj_score_token (nn.Embedding): Embedding for object score token.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.\n output_upscaling (nn.Sequential): Upscaling layers for output.\n use_high_res_features (bool): Whether to use high-resolution features.\n conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).\n conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).\n output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.\n iou_prediction_head (MLP): MLP for IOU prediction.\n pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.\n dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.\n dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.\n dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.\n\nMethods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Predict instance segmentation masks from image and prompt embeddings.\n _get_stability_scores: Compute mask stability scores based on IoU between thresholds.\n _dynamic_multimask_via_stability: Dynamically select the most stable mask output.\n\nExamples:\n >>> image_embeddings = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)\n >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)\n >>> decoder = SAM2MaskDecoder(256, transformer)\n >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False\n ... 
)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "ultralytics.nn.modules.MLP", "ultralytics.nn.modules.LayerNorm2d", "nn.Module" ], "chunk_id": "class_SAM2MaskDecoder_5d57bda5" }, { "content": "from typing import List, Optional, Tuple, Type", "chunk_type": "import", "name": "List, Optional, Tuple, Type", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple, Type_54fc0454" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_c5b44249" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_08fde6ac" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_da6be143" }, { "content": "from ultralytics.nn.modules import LayerNorm2d", "chunk_type": "import", "name": "LayerNorm2d", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LayerNorm2d_68f0168c" }, { "content": "from .blocks import (\n Block,\n CXBlock,\n Fuser,\n MaskDownSampler,\n MultiScaleBlock,\n PatchEmbed,\n PositionEmbeddingRandom,\n PositionEmbeddingSine,\n)", "chunk_type": "import", "name": "Block, CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PatchEmbed, PositionEmbeddingRandom, PositionEmbeddingSine", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 11, "end_line": 20, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Block, CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PatchEmbed, PositionEmbeddingRandom, PositionEmbeddingSine_7e4f3b8e" }, { "content": "class ImageEncoderViT(nn.Module):\n \"\"\"\n An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.\n\n This class processes images by splitting them into patches, applying transformer blocks, and generating a final\n encoded 
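A hedged sketch for SAM2MaskDecoder mirroring the docstring shapes above. The transformer, its hyperparameters, and the import paths are assumptions; pred_obj_scores=True is set here to exercise the object-score head.

```python
import torch
from ultralytics.models.sam.modules.decoders import SAM2MaskDecoder
from ultralytics.models.sam.modules.transformer import TwoWayTransformer  # assumed sibling module

decoder = SAM2MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048),
    pred_obj_scores=True,                 # adds the object-score token and head
)
masks, iou_pred, sam_tokens, obj_logits = decoder(
    image_embeddings=torch.rand(1, 256, 64, 64),
    image_pe=torch.rand(1, 256, 64, 64),  # must have batch size 1 (see the assert above)
    sparse_prompt_embeddings=torch.rand(1, 2, 256),
    dense_prompt_embeddings=torch.rand(1, 256, 64, 64),
    multimask_output=True,
    repeat_image=False,
)
print(masks.shape, iou_pred.shape, sam_tokens.shape, obj_logits.shape)
# (1, 3, 256, 256), (1, 3), (1, 1, 256), (1, 1)
```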
representation through a neck module.\n\n Attributes:\n img_size (int): Dimension of input images, assumed to be square.\n patch_embed (PatchEmbed): Module for patch embedding.\n pos_embed (nn.Parameter | None): Absolute positional embedding for patches.\n blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.\n neck (nn.Sequential): Neck module to further process the output.\n\n Methods:\n forward: Process input through patch embedding, positional embedding, blocks, and neck.\n\n Examples:\n >>> import torch\n >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)\n >>> input_image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(input_image)\n >>> print(output.shape)\n \"\"\"\n\n def __init__(\n self,\n img_size: int = 1024,\n patch_size: int = 16,\n in_chans: int = 3,\n embed_dim: int = 768,\n depth: int = 12,\n num_heads: int = 12,\n mlp_ratio: float = 4.0,\n out_chans: int = 256,\n qkv_bias: bool = True,\n norm_layer: Type[nn.Module] = nn.LayerNorm,\n act_layer: Type[nn.Module] = nn.GELU,\n use_abs_pos: bool = True,\n use_rel_pos: bool = False,\n rel_pos_zero_init: bool = True,\n window_size: int = 0,\n global_attn_indexes: Tuple[int, ...] = (),\n ) -> None:\n \"\"\"\n Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.\n\n Args:\n img_size (int): Input image size, assumed to be square.\n patch_size (int): Size of image patches.\n in_chans (int): Number of input image channels.\n embed_dim (int): Dimension of patch embeddings.\n depth (int): Number of transformer blocks.\n num_heads (int): Number of attention heads in each block.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n out_chans (int): Number of output channels from the neck module.\n qkv_bias (bool): If True, adds learnable bias to query, key, value projections.\n norm_layer (Type[nn.Module]): Type of normalization layer to use.\n act_layer (Type[nn.Module]): Type of activation layer to use.\n use_abs_pos (bool): If True, uses absolute positional embeddings.\n use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.\n rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.\n window_size (int): Size of attention window for windowed attention blocks.\n global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.\n\n Examples:\n >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)\n >>> input_image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(input_image)\n >>> print(output.shape)\n \"\"\"\n super().__init__()\n self.img_size = img_size\n\n self.patch_embed = PatchEmbed(\n kernel_size=(patch_size, patch_size),\n stride=(patch_size, patch_size),\n in_chans=in_chans,\n embed_dim=embed_dim,\n )\n\n self.pos_embed: Optional[nn.Parameter] = None\n if use_abs_pos:\n # Initialize absolute positional embedding with pretrain image size\n self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))\n\n self.blocks = nn.ModuleList()\n for i in range(depth):\n block = Block(\n dim=embed_dim,\n num_heads=num_heads,\n mlp_ratio=mlp_ratio,\n qkv_bias=qkv_bias,\n norm_layer=norm_layer,\n act_layer=act_layer,\n use_rel_pos=use_rel_pos,\n rel_pos_zero_init=rel_pos_zero_init,\n window_size=window_size if i not in global_attn_indexes else 0,\n input_size=(img_size // patch_size, img_size // patch_size),\n )\n 
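A small sketch of the per-block attention selection performed in the __init__ loop above: blocks whose index appears in global_attn_indexes use global attention (window_size=0), all others use windowed attention.

```python
# Illustrative values for depth, window_size and global_attn_indexes.
depth, window_size, global_attn_indexes = 12, 14, (2, 5, 8, 11)
per_block = [window_size if i not in global_attn_indexes else 0 for i in range(depth)]
print(per_block)  # [14, 14, 0, 14, 14, 0, 14, 14, 0, 14, 14, 0]
```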
self.blocks.append(block)\n\n self.neck = nn.Sequential(\n nn.Conv2d(\n embed_dim,\n out_chans,\n kernel_size=1,\n bias=False,\n ),\n LayerNorm2d(out_chans),\n nn.Conv2d(\n out_chans,\n out_chans,\n kernel_size=3,\n padding=1,\n bias=False,\n ),\n LayerNorm2d(out_chans),\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through patch embedding, positional embedding, transformer blocks, and neck module.\"\"\"\n x = self.patch_embed(x)\n if self.pos_embed is not None:\n pos_embed = (\n F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)\n if self.img_size != 1024\n else self.pos_embed\n )\n x = x + pos_embed\n for blk in self.blocks:\n x = blk(x)\n return self.neck(x.permute(0, 3, 1, 2))", "chunk_type": "class", "name": "ImageEncoderViT", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 23, "end_line": 155, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.\n\nThis class processes images by splitting them into patches, applying transformer blocks, and generating a final\nencoded representation through a neck module.\n\nAttributes:\n img_size (int): Dimension of input images, assumed to be square.\n patch_embed (PatchEmbed): Module for patch embedding.\n pos_embed (nn.Parameter | None): Absolute positional embedding for patches.\n blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.\n neck (nn.Sequential): Neck module to further process the output.\n\nMethods:\n forward: Process input through patch embedding, positional embedding, blocks, and neck.\n\nExamples:\n >>> import torch\n >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)\n >>> input_image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(input_image)\n >>> print(output.shape)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "blocks.Block", "blocks.CXBlock", "blocks.Fuser", "blocks.MaskDownSampler", "blocks.MultiScaleBlock", "blocks.PatchEmbed", "blocks.PositionEmbeddingRandom", "blocks.PositionEmbeddingSine", "nn.Module" ], "chunk_id": "class_ImageEncoderViT_137a3367" }, { "content": "class PromptEncoder(nn.Module):\n \"\"\"\n Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.\n\n Attributes:\n embed_dim (int): Dimension of the embeddings.\n input_image_size (Tuple[int, int]): Size of the input image as (H, W).\n image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).\n pe_layer (PositionEmbeddingRandom): Module for random position embedding.\n num_point_embeddings (int): Number of point embeddings for different types of points.\n point_embeddings (nn.ModuleList): List of point embeddings.\n not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.\n mask_input_size (Tuple[int, int]): Size of the input mask.\n mask_downscaling (nn.Sequential): Neural network for downscaling the mask.\n no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.\n\n Methods:\n get_dense_pe: Return the positional encoding used to encode point prompts.\n forward: Embed different types of prompts, 
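A hedged sketch for ImageEncoderViT with a deliberately tiny configuration so it runs quickly. use_abs_pos=False is used here because the positional-embedding rescaling branch in forward() assumes embeddings pretrained at the default 1024-pixel resolution; the import path is inferred from the chunk's file_path.

```python
import torch
from ultralytics.models.sam.modules.encoders import ImageEncoderViT

encoder = ImageEncoderViT(
    img_size=256, patch_size=16, embed_dim=96, depth=2, num_heads=3,
    out_chans=256, use_abs_pos=False,
)
x = torch.randn(1, 3, 256, 256)
feat = encoder(x)
print(feat.shape)  # torch.Size([1, 256, 16, 16]) -> (B, out_chans, img_size/16, img_size/16)
```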
returning both sparse and dense embeddings.\n\n Examples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)\n >>> print(sparse_embeddings.shape, dense_embeddings.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n\n def __init__(\n self,\n embed_dim: int,\n image_embedding_size: Tuple[int, int],\n input_image_size: Tuple[int, int],\n mask_in_chans: int,\n activation: Type[nn.Module] = nn.GELU,\n ) -> None:\n \"\"\"\n Initialize the PromptEncoder module for encoding various types of prompts.\n\n Args:\n embed_dim (int): The dimension of the embeddings.\n image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).\n input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).\n mask_in_chans (int): The number of hidden channels used for encoding input masks.\n activation (Type[nn.Module]): The activation function to use when encoding input masks.\n\n Examples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)\n >>> print(sparse_embeddings.shape, dense_embeddings.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n super().__init__()\n self.embed_dim = embed_dim\n self.input_image_size = input_image_size\n self.image_embedding_size = image_embedding_size\n self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)\n\n self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners\n point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]\n self.point_embeddings = nn.ModuleList(point_embeddings)\n self.not_a_point_embed = nn.Embedding(1, embed_dim)\n\n self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])\n self.mask_downscaling = nn.Sequential(\n nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),\n LayerNorm2d(mask_in_chans // 4),\n activation(),\n nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),\n LayerNorm2d(mask_in_chans),\n activation(),\n nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),\n )\n self.no_mask_embed = nn.Embedding(1, embed_dim)\n\n def get_dense_pe(self) -> torch.Tensor:\n \"\"\"\n Return the dense positional encoding used for encoding point prompts.\n\n Generate a positional encoding for a dense set of points matching the shape of the image\n encoding. 
The encoding is used to provide spatial information to the model when processing point prompts.\n\n Returns:\n (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the\n height and width of the image embedding size, respectively.\n\n Examples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> dense_pe = prompt_encoder.get_dense_pe()\n >>> print(dense_pe.shape)\n torch.Size([1, 256, 64, 64])\n \"\"\"\n return self.pe_layer(self.image_embedding_size).unsqueeze(0)\n\n def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:\n \"\"\"Embed point prompts by applying positional encoding and label-specific embeddings.\"\"\"\n points = points + 0.5 # Shift to center of pixel\n if pad:\n padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)\n padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)\n points = torch.cat([points, padding_point], dim=1)\n labels = torch.cat([labels, padding_label], dim=1)\n point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)\n point_embedding[labels == -1] = 0.0\n point_embedding[labels == -1] += self.not_a_point_embed.weight\n point_embedding[labels == 0] += self.point_embeddings[0].weight\n point_embedding[labels == 1] += self.point_embeddings[1].weight\n point_embedding[labels == 2] += self.point_embeddings[2].weight\n point_embedding[labels == 3] += self.point_embeddings[3].weight\n return point_embedding\n\n def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:\n \"\"\"Embed box prompts by applying positional encoding and adding corner embeddings.\"\"\"\n boxes = boxes + 0.5 # Shift to center of pixel\n coords = boxes.reshape(-1, 2, 2)\n corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)\n corner_embedding[:, 0, :] += self.point_embeddings[2].weight\n corner_embedding[:, 1, :] += self.point_embeddings[3].weight\n return corner_embedding\n\n def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:\n \"\"\"Embed mask inputs by downscaling and processing through convolutional layers.\"\"\"\n return self.mask_downscaling(masks)\n\n @staticmethod\n def _get_batch_size(\n points: Optional[Tuple[torch.Tensor, torch.Tensor]],\n boxes: Optional[torch.Tensor],\n masks: Optional[torch.Tensor],\n ) -> int:\n \"\"\"Get the batch size of the output given the batch size of the input prompts.\"\"\"\n if points is not None:\n return points[0].shape[0]\n elif boxes is not None:\n return boxes.shape[0]\n elif masks is not None:\n return masks.shape[0]\n else:\n return 1\n\n def _get_device(self) -> torch.device:\n \"\"\"Return the device of the first point embedding's weight tensor.\"\"\"\n return self.point_embeddings[0].weight.device\n\n def forward(\n self,\n points: Optional[Tuple[torch.Tensor, torch.Tensor]],\n boxes: Optional[torch.Tensor],\n masks: Optional[torch.Tensor],\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Embed different types of prompts, returning both sparse and dense embeddings.\n\n Args:\n points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. 
The first\n tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with\n shape (B, N).\n boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.\n masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).\n\n Returns:\n sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).\n dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).\n\n Examples:\n >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_emb, dense_emb = encoder(points, boxes, masks)\n >>> print(sparse_emb.shape, dense_emb.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n bs = self._get_batch_size(points, boxes, masks)\n sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())\n if points is not None:\n coords, labels = points\n point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))\n sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)\n if boxes is not None:\n box_embeddings = self._embed_boxes(boxes)\n sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)\n\n if masks is not None:\n dense_embeddings = self._embed_masks(masks)\n else:\n dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(\n bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]\n )\n\n return sparse_embeddings, dense_embeddings", "chunk_type": "class", "name": "PromptEncoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 158, "end_line": 353, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": "Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.\n\nAttributes:\n embed_dim (int): Dimension of the embeddings.\n input_image_size (Tuple[int, int]): Size of the input image as (H, W).\n image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).\n pe_layer (PositionEmbeddingRandom): Module for random position embedding.\n num_point_embeddings (int): Number of point embeddings for different types of points.\n point_embeddings (nn.ModuleList): List of point embeddings.\n not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.\n mask_input_size (Tuple[int, int]): Size of the input mask.\n mask_downscaling (nn.Sequential): Neural network for downscaling the mask.\n no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.\n\nMethods:\n get_dense_pe: Return the positional encoding used to encode point prompts.\n forward: Embed different types of prompts, returning both sparse and dense embeddings.\n\nExamples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)\n >>> print(sparse_embeddings.shape, dense_embeddings.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", 
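A hedged usage sketch for PromptEncoder following the shapes documented above: five labeled points plus one box yield seven sparse tokens, and the mask prompt is passed at four times the embedding resolution. The import path is assumed from the chunk's file_path.

```python
import torch
from ultralytics.models.sam.modules.encoders import PromptEncoder

enc = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
                    input_image_size=(1024, 1024), mask_in_chans=16)
points = (torch.rand(1, 5, 2) * 1024, torch.randint(0, 4, (1, 5)))  # pixel-space coords, labels 0-3
boxes = torch.rand(1, 2, 2) * 1024        # one box given as two (x, y) corners
masks = torch.rand(1, 1, 256, 256)        # 4x the embedding size, per mask_input_size
sparse, dense = enc(points, boxes, masks)
print(sparse.shape, dense.shape)          # torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
```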
"torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "blocks.Block", "blocks.CXBlock", "blocks.Fuser", "blocks.MaskDownSampler", "blocks.MultiScaleBlock", "blocks.PatchEmbed", "blocks.PositionEmbeddingRandom", "blocks.PositionEmbeddingSine", "nn.Module" ], "chunk_id": "class_PromptEncoder_cef65f27" }, { "content": "class MemoryEncoder(nn.Module):\n \"\"\"\n Encode pixel features and masks into a memory representation for efficient image segmentation.\n\n This class processes pixel-level features and masks, fusing them to generate encoded memory representations\n suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).\n\n Attributes:\n mask_downsampler (MaskDownSampler): Module for downsampling input masks.\n pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.\n fuser (Fuser): Module for fusing pixel features and masks.\n position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.\n out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.\n\n Methods:\n forward: Process input pixel features and masks to generate encoded memory representations.\n\n Examples:\n >>> import torch\n >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)\n >>> pix_feat = torch.randn(1, 256, 64, 64)\n >>> masks = torch.randn(1, 1, 64, 64)\n >>> encoded_feat, pos = encoder(pix_feat, masks)\n >>> print(encoded_feat.shape, pos.shape)\n torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])\n \"\"\"\n\n def __init__(\n self,\n out_dim,\n in_dim=256, # in_dim of pix_feats\n ):\n \"\"\"\n Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.\n\n This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations\n suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).\n\n Args:\n out_dim (int): Output dimension of the encoded features.\n in_dim (int): Input dimension of the pixel features.\n\n Examples:\n >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)\n >>> pix_feat = torch.randn(1, 256, 64, 64)\n >>> masks = torch.randn(1, 1, 64, 64)\n >>> encoded_feat, pos = encoder(pix_feat, masks)\n >>> print(encoded_feat.shape, pos.shape)\n torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])\n \"\"\"\n super().__init__()\n\n self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)\n\n self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)\n self.fuser = Fuser(CXBlock(dim=256), num_layers=2)\n self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)\n self.out_proj = nn.Identity()\n if out_dim != in_dim:\n self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)\n\n def forward(\n self,\n pix_feat: torch.Tensor,\n masks: torch.Tensor,\n skip_mask_sigmoid: bool = False,\n ) -> dict:\n \"\"\"Process pixel features and masks to generate encoded memory representations for segmentation.\"\"\"\n if not skip_mask_sigmoid:\n masks = F.sigmoid(masks)\n masks = self.mask_downsampler(masks)\n\n # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA\n pix_feat = pix_feat.to(masks.device)\n\n x = self.pix_feat_proj(pix_feat)\n x = x + masks\n x = self.fuser(x)\n x = self.out_proj(x)\n\n pos = self.position_encoding(x).to(x.dtype)\n\n return {\"vision_features\": x, \"vision_pos_enc\": [pos]}", "chunk_type": "class", "name": "MemoryEncoder", "file_path": 
"ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 356, "end_line": 438, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": "Encode pixel features and masks into a memory representation for efficient image segmentation.\n\nThis class processes pixel-level features and masks, fusing them to generate encoded memory representations\nsuitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).\n\nAttributes:\n mask_downsampler (MaskDownSampler): Module for downsampling input masks.\n pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.\n fuser (Fuser): Module for fusing pixel features and masks.\n position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.\n out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.\n\nMethods:\n forward: Process input pixel features and masks to generate encoded memory representations.\n\nExamples:\n >>> import torch\n >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)\n >>> pix_feat = torch.randn(1, 256, 64, 64)\n >>> masks = torch.randn(1, 1, 64, 64)\n >>> encoded_feat, pos = encoder(pix_feat, masks)\n >>> print(encoded_feat.shape, pos.shape)\n torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "blocks.Block", "blocks.CXBlock", "blocks.Fuser", "blocks.MaskDownSampler", "blocks.MultiScaleBlock", "blocks.PatchEmbed", "blocks.PositionEmbeddingRandom", "blocks.PositionEmbeddingSine", "nn.Module" ], "chunk_id": "class_MemoryEncoder_789e12d3" }, { "content": "class ImageEncoder(nn.Module):\n \"\"\"\n Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.\n\n This class combines a trunk network for feature extraction with a neck network for feature refinement\n and positional encoding generation. It can optionally discard the lowest resolution features.\n\n Attributes:\n trunk (nn.Module): The trunk network for initial feature extraction.\n neck (nn.Module): The neck network for feature refinement and positional encoding generation.\n scalp (int): Number of lowest resolution feature levels to discard.\n\n Methods:\n forward: Process the input image through the trunk and neck networks.\n\n Examples:\n >>> trunk = SomeTrunkNetwork()\n >>> neck = SomeNeckNetwork()\n >>> encoder = ImageEncoder(trunk, neck, scalp=1)\n >>> image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(image)\n >>> print(output.keys())\n dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])\n \"\"\"\n\n def __init__(\n self,\n trunk: nn.Module,\n neck: nn.Module,\n scalp: int = 0,\n ):\n \"\"\"\n Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.\n\n This encoder combines a trunk network for feature extraction with a neck network for feature refinement\n and positional encoding generation. 
It can optionally discard the lowest resolution features.\n\n Args:\n trunk (nn.Module): The trunk network for initial feature extraction.\n neck (nn.Module): The neck network for feature refinement and positional encoding generation.\n scalp (int): Number of lowest resolution feature levels to discard.\n\n Examples:\n >>> trunk = SomeTrunkNetwork()\n >>> neck = SomeNeckNetwork()\n >>> encoder = ImageEncoder(trunk, neck, scalp=1)\n >>> image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(image)\n >>> print(output.keys())\n dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])\n \"\"\"\n super().__init__()\n self.trunk = trunk\n self.neck = neck\n self.scalp = scalp\n assert self.trunk.channel_list == self.neck.backbone_channel_list, (\n f\"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match.\"\n )\n\n def forward(self, sample: torch.Tensor):\n \"\"\"Encode input through trunk and neck networks, returning multiscale features and positional encodings.\"\"\"\n features, pos = self.neck(self.trunk(sample))\n if self.scalp > 0:\n # Discard the lowest resolution features\n features, pos = features[: -self.scalp], pos[: -self.scalp]\n\n src = features[-1]\n return {\n \"vision_features\": src,\n \"vision_pos_enc\": pos,\n \"backbone_fpn\": features,\n }", "chunk_type": "class", "name": "ImageEncoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 441, "end_line": 512, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.\n\nThis class combines a trunk network for feature extraction with a neck network for feature refinement\nand positional encoding generation. 
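A hedged composition sketch for ImageEncoder: the trunk/neck pair mirrors how the Hiera and FpnNeck classes that appear later in this file are combined in SAM 2. The import path is inferred from the file_path field; the 1024x1024 input size and scalp=1 are illustrative assumptions.

```python
# Hedged sketch: building ImageEncoder from the Hiera trunk and FpnNeck neck
# defined in the same encoders.py file. Config values are illustrative.
import torch
from ultralytics.models.sam.modules.encoders import FpnNeck, Hiera, ImageEncoder

trunk = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
neck = FpnNeck(d_model=256, backbone_channel_list=trunk.channel_list)  # satisfies the channel-dim assert
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)                # scalp=1 drops the coarsest level

image = torch.randn(1, 3, 1024, 1024)
out = encoder(image)
print(out["vision_features"].shape)  # lowest-resolution feature map that survives scalp
print(len(out["backbone_fpn"]))      # remaining multiscale FPN levels
print(len(out["vision_pos_enc"]))    # one positional encoding per retained level
```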
It can optionally discard the lowest resolution features.\n\nAttributes:\n trunk (nn.Module): The trunk network for initial feature extraction.\n neck (nn.Module): The neck network for feature refinement and positional encoding generation.\n scalp (int): Number of lowest resolution feature levels to discard.\n\nMethods:\n forward: Process the input image through the trunk and neck networks.\n\nExamples:\n >>> trunk = SomeTrunkNetwork()\n >>> neck = SomeNeckNetwork()\n >>> encoder = ImageEncoder(trunk, neck, scalp=1)\n >>> image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(image)\n >>> print(output.keys())\n dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "blocks.Block", "blocks.CXBlock", "blocks.Fuser", "blocks.MaskDownSampler", "blocks.MultiScaleBlock", "blocks.PatchEmbed", "blocks.PositionEmbeddingRandom", "blocks.PositionEmbeddingSine", "nn.Module" ], "chunk_id": "class_ImageEncoder_25645eaa" }, { "content": "class FpnNeck(nn.Module):\n \"\"\"\n A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.\n\n This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,\n similar to ViT positional embedding interpolation.\n\n Attributes:\n position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.\n convs (nn.ModuleList): List of convolutional layers for each backbone level.\n backbone_channel_list (List[int]): List of channel dimensions from the backbone.\n fpn_interp_model (str): Interpolation mode for FPN feature resizing.\n fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.\n fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.\n\n Methods:\n forward: Perform forward pass through the FPN neck.\n\n Examples:\n >>> backbone_channels = [64, 128, 256, 512]\n >>> fpn_neck = FpnNeck(256, backbone_channels)\n >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]\n >>> outputs, positions = fpn_neck(inputs)\n >>> print(len(outputs), len(positions))\n 4 4\n \"\"\"\n\n def __init__(\n self,\n d_model: int,\n backbone_channel_list: List[int],\n kernel_size: int = 1,\n stride: int = 1,\n padding: int = 0,\n fpn_interp_model: str = \"bilinear\",\n fuse_type: str = \"sum\",\n fpn_top_down_levels: Optional[List[int]] = None,\n ):\n \"\"\"\n Initialize a modified Feature Pyramid Network (FPN) neck.\n\n This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,\n similar to ViT positional embedding interpolation.\n\n Args:\n d_model (int): Dimension of the model.\n backbone_channel_list (List[int]): List of channel dimensions from the backbone.\n kernel_size (int): Kernel size for the convolutional layers.\n stride (int): Stride for the convolutional layers.\n padding (int): Padding for the convolutional layers.\n fpn_interp_model (str): Interpolation mode for FPN feature resizing.\n fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.\n fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.\n\n Examples:\n >>> backbone_channels = [64, 128, 256, 512]\n >>> fpn_neck = FpnNeck(256, backbone_channels)\n >>> print(fpn_neck)\n \"\"\"\n super().__init__()\n self.position_encoding = 
PositionEmbeddingSine(num_pos_feats=256)\n self.convs = nn.ModuleList()\n self.backbone_channel_list = backbone_channel_list\n for dim in backbone_channel_list:\n current = nn.Sequential()\n current.add_module(\n \"conv\",\n nn.Conv2d(\n in_channels=dim,\n out_channels=d_model,\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n ),\n )\n\n self.convs.append(current)\n self.fpn_interp_model = fpn_interp_model\n assert fuse_type in {\"sum\", \"avg\"}\n self.fuse_type = fuse_type\n\n # Levels to have top-down features in its outputs\n # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3\n # have top-down propagation, while outputs of level 0 and level 1 have only\n # lateral features from the same backbone level\n if fpn_top_down_levels is None:\n # Default is to have top-down features on all levels\n fpn_top_down_levels = range(len(self.convs))\n self.fpn_top_down_levels = list(fpn_top_down_levels)\n\n def forward(self, xs: List[torch.Tensor]):\n \"\"\"\n Perform forward pass through the Feature Pyramid Network (FPN) neck.\n\n This method processes a list of input tensors from the backbone through the FPN, applying lateral connections\n and top-down feature fusion. It generates output feature maps and corresponding positional encodings.\n\n Args:\n xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).\n\n Returns:\n out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape\n (B, d_model, H, W).\n pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.\n\n Examples:\n >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])\n >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]\n >>> outputs, positions = fpn_neck(inputs)\n >>> print(len(outputs), len(positions))\n 4 4\n \"\"\"\n out = [None] * len(self.convs)\n pos = [None] * len(self.convs)\n assert len(xs) == len(self.convs)\n # FPN forward pass\n # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py\n prev_features = None\n # Forward in top-down order (from low to high resolution)\n n = len(self.convs) - 1\n for i in range(n, -1, -1):\n x = xs[i]\n lateral_features = self.convs[n - i](x)\n if i in self.fpn_top_down_levels and prev_features is not None:\n top_down_features = F.interpolate(\n prev_features.to(dtype=torch.float32),\n scale_factor=2.0,\n mode=self.fpn_interp_model,\n align_corners=(None if self.fpn_interp_model == \"nearest\" else False),\n antialias=False,\n )\n prev_features = lateral_features + top_down_features\n if self.fuse_type == \"avg\":\n prev_features /= 2\n else:\n prev_features = lateral_features\n x_out = prev_features\n out[i] = x_out\n pos[i] = self.position_encoding(x_out).to(x_out.dtype)\n\n return out, pos", "chunk_type": "class", "name": "FpnNeck", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 515, "end_line": 655, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.\n\nThis FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,\nsimilar to ViT positional embedding interpolation.\n\nAttributes:\n position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.\n convs (nn.ModuleList): List of convolutional layers for each backbone level.\n backbone_channel_list 
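A shape-contract sketch for the FpnNeck forward pass just shown: backbone_channel_list is ordered coarsest first (matching Hiera.channel_list), while the input list runs fine to coarse with each level half the spatial size of the previous one, since xs[i] is routed through convs[n - i] and top-down features are upsampled by 2x. The import path and the Hiera-B+-like dimensions are assumptions.

```python
# Hedged sketch: FpnNeck input/output contract with Hiera-B+-like dimensions.
import torch
from ultralytics.models.sam.modules.encoders import FpnNeck

neck = FpnNeck(d_model=256, backbone_channel_list=[768, 384, 192, 96])
xs = [
    torch.randn(1, 96, 256, 256),   # stage 1 (finest); xs[i] feeds convs[n - i]
    torch.randn(1, 192, 128, 128),  # stage 2
    torch.randn(1, 384, 64, 64),    # stage 3
    torch.randn(1, 768, 32, 32),    # stage 4 (coarsest), where top-down fusion starts
]
outs, pos = neck(xs)
for o, p in zip(outs, pos):
    print(o.shape, p.shape)  # features projected to d_model channels, plus matching pos encodings
```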
(List[int]): List of channel dimensions from the backbone.\n fpn_interp_model (str): Interpolation mode for FPN feature resizing.\n fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.\n fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.\n\nMethods:\n forward: Perform forward pass through the FPN neck.\n\nExamples:\n >>> backbone_channels = [64, 128, 256, 512]\n >>> fpn_neck = FpnNeck(256, backbone_channels)\n >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]\n >>> outputs, positions = fpn_neck(inputs)\n >>> print(len(outputs), len(positions))\n 4 4", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "blocks.Block", "blocks.CXBlock", "blocks.Fuser", "blocks.MaskDownSampler", "blocks.MultiScaleBlock", "blocks.PatchEmbed", "blocks.PositionEmbeddingRandom", "blocks.PositionEmbeddingSine", "nn.Module" ], "chunk_id": "class_FpnNeck_3bf0d7be" }, { "content": "class Hiera(nn.Module):\n \"\"\"\n Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.\n\n This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for\n efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,\n with optional pooling and global attention mechanisms.\n\n Attributes:\n window_spec (Tuple[int, ...]): Window sizes for each stage.\n q_stride (Tuple[int, int]): Downsampling stride between stages.\n stage_ends (List[int]): Indices of the last block in each stage.\n q_pool_blocks (List[int]): Indices of blocks where pooling is applied.\n return_interm_layers (bool): Whether to return intermediate layer outputs.\n patch_embed (PatchEmbed): Module for patch embedding.\n global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.\n window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.\n pos_embed (nn.Parameter): Positional embedding for the background.\n pos_embed_window (nn.Parameter): Positional embedding for the window.\n blocks (nn.ModuleList): List of MultiScaleBlock modules.\n channel_list (List[int]): List of output channel dimensions for each stage.\n\n Methods:\n _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.\n forward: Perform the forward pass through the Hiera model.\n\n Examples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)\n \"\"\"\n\n def __init__(\n self,\n embed_dim: int = 96, # initial embed dim\n num_heads: int = 1, # initial number of heads\n drop_path_rate: float = 0.0, # stochastic depth\n q_pool: int = 3, # number of q_pool stages\n q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages\n stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage\n dim_mul: float = 2.0, # dim_mul factor at stage shift\n head_mul: float = 2.0, # head_mul factor at stage shift\n window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),\n # window size per stage, when not using global att.\n window_spec: Tuple[int, ...] 
= (\n 8,\n 4,\n 14,\n 7,\n ),\n # global attn in these blocks\n global_att_blocks: Tuple[int, ...] = (\n 12,\n 16,\n 20,\n ),\n return_interm_layers=True, # return feats from every stage\n ):\n \"\"\"\n Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.\n\n Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction\n in image processing tasks. It uses a series of transformer blocks organized into stages, with optional\n pooling and global attention mechanisms.\n\n Args:\n embed_dim (int): Initial embedding dimension for the model.\n num_heads (int): Initial number of attention heads.\n drop_path_rate (float): Stochastic depth rate.\n q_pool (int): Number of query pooling stages.\n q_stride (Tuple[int, int]): Downsampling stride between stages.\n stages (Tuple[int, ...]): Number of blocks per stage.\n dim_mul (float): Dimension multiplier factor at stage transitions.\n head_mul (float): Head multiplier factor at stage transitions.\n window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.\n window_spec (Tuple[int, ...]): Window sizes for each stage when not using global attention.\n global_att_blocks (Tuple[int, ...]): Indices of blocks that use global attention.\n return_interm_layers (bool): Whether to return intermediate layer outputs.\n\n Examples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)\n \"\"\"\n super().__init__()\n\n assert len(stages) == len(window_spec)\n self.window_spec = window_spec\n\n depth = sum(stages)\n self.q_stride = q_stride\n self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]\n assert 0 <= q_pool <= len(self.stage_ends[:-1])\n self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]\n self.return_interm_layers = return_interm_layers\n\n self.patch_embed = PatchEmbed(\n embed_dim=embed_dim,\n kernel_size=(7, 7),\n stride=(4, 4),\n padding=(3, 3),\n )\n # Which blocks have global attention?\n self.global_att_blocks = global_att_blocks\n\n # Windowed positional embedding (https://arxiv.org/abs/2311.05613)\n self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size\n self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))\n self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))\n\n dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule\n\n cur_stage = 1\n self.blocks = nn.ModuleList()\n\n for i in range(depth):\n dim_out = embed_dim\n # Lags by a block, so first block of next stage uses an initial window size\n # of previous stage and final window size of current stage\n window_size = self.window_spec[cur_stage - 1]\n\n if self.global_att_blocks is not None:\n window_size = 0 if i in self.global_att_blocks else window_size\n\n if i - 1 in self.stage_ends:\n dim_out = int(embed_dim * dim_mul)\n num_heads = int(num_heads * head_mul)\n cur_stage += 1\n\n block = MultiScaleBlock(\n dim=embed_dim,\n dim_out=dim_out,\n num_heads=num_heads,\n drop_path=dpr[i],\n q_stride=self.q_stride if i in self.q_pool_blocks else None,\n window_size=window_size,\n )\n\n embed_dim = dim_out\n self.blocks.append(block)\n\n self.channel_list = (\n [self.blocks[i].dim_out for i in 
self.stage_ends[::-1]]\n if return_interm_layers\n else [self.blocks[-1].dim_out]\n )\n\n def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:\n \"\"\"Generate positional embeddings by interpolating and combining window and background embeddings.\"\"\"\n h, w = hw\n window_embed = self.pos_embed_window\n pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode=\"bicubic\")\n pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])\n pos_embed = pos_embed.permute(0, 2, 3, 1)\n return pos_embed\n\n def forward(self, x: torch.Tensor) -> List[torch.Tensor]:\n \"\"\"\n Perform forward pass through Hiera model, extracting multiscale features from input images.\n\n Args:\n x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.\n\n Returns:\n (List[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where\n C_i is the channel dimension and H_i, W_i are the spatial dimensions at scale i. The list is ordered\n from highest resolution (fine features) to lowest resolution (coarse features) if return_interm_layers\n is True, otherwise contains only the final output.\n\n Examples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)\n \"\"\"\n x = self.patch_embed(x)\n # x: (B, H, W, C)\n\n # Add positional embedding\n x = x + self._get_pos_embed(x.shape[1:3])\n\n outputs = []\n for i, blk in enumerate(self.blocks):\n x = blk(x)\n if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):\n feats = x.permute(0, 3, 1, 2)\n outputs.append(feats)\n\n return outputs", "chunk_type": "class", "name": "Hiera", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py", "start_line": 658, "end_line": 851, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.\n\nThis class implements a Hiera model, which is a hierarchical vision transformer architecture designed for\nefficient multiscale feature extraction. 
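A quick sketch mirroring the Hiera docstring example: the default configuration returns one feature map per stage, fine to coarse, and channel_list records the per-stage output widths (import path inferred from the file_path field).

```python
# Hedged sketch: Hiera multiscale outputs with the default configuration.
import torch
from ultralytics.models.sam.modules.encoders import Hiera

model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
print(model.channel_list)  # per-stage output channels (coarsest first)

x = torch.randn(1, 3, 224, 224)
for feat in model(x):
    print(feat.shape)      # (B, C_i, H_i, W_i): channels double and resolution halves per stage
```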
It uses a series of transformer blocks organized into stages,\nwith optional pooling and global attention mechanisms.\n\nAttributes:\n window_spec (Tuple[int, ...]): Window sizes for each stage.\n q_stride (Tuple[int, int]): Downsampling stride between stages.\n stage_ends (List[int]): Indices of the last block in each stage.\n q_pool_blocks (List[int]): Indices of blocks where pooling is applied.\n return_interm_layers (bool): Whether to return intermediate layer outputs.\n patch_embed (PatchEmbed): Module for patch embedding.\n global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.\n window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.\n pos_embed (nn.Parameter): Positional embedding for the background.\n pos_embed_window (nn.Parameter): Positional embedding for the window.\n blocks (nn.ModuleList): List of MultiScaleBlock modules.\n channel_list (List[int]): List of output channel dimensions for each stage.\n\nMethods:\n _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.\n forward: Perform the forward pass through the Hiera model.\n\nExamples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "typing.Type", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "blocks.Block", "blocks.CXBlock", "blocks.Fuser", "blocks.MaskDownSampler", "blocks.MultiScaleBlock", "blocks.PatchEmbed", "blocks.PositionEmbeddingRandom", "blocks.PositionEmbeddingSine", "nn.Module" ], "chunk_id": "class_Hiera_8bb910ce" }, { "content": "import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_e7340124" }, { "content": "from typing import Optional", "chunk_type": "import", "name": "Optional", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Optional_9b86dd1e" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_9d64bdea" }, { "content": "from torch import nn", "chunk_type": "import", "name": "nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_nn_890e43f1" }, { "content": "from .blocks import RoPEAttention", "chunk_type": "import", "name": "RoPEAttention", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RoPEAttention_4b80ab1c" }, { "content": "class MemoryAttentionLayer(nn.Module):\n \"\"\"\n Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.\n\n This class combines self-attention, cross-attention, and feedforward components to process input tensors and\n generate memory-based attention outputs.\n\n Attributes:\n d_model (int): Dimensionality of the model.\n dim_feedforward (int): Dimensionality of the feedforward network.\n dropout_value (float): Dropout rate for regularization.\n self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).\n cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.\n linear1 (nn.Linear): First linear layer of the feedforward network.\n linear2 (nn.Linear): Second linear layer of the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization for self-attention output.\n norm2 (nn.LayerNorm): Layer normalization for cross-attention output.\n norm3 (nn.LayerNorm): Layer normalization for feedforward network output.\n dropout1 (nn.Dropout): Dropout layer after self-attention.\n dropout2 (nn.Dropout): Dropout layer after cross-attention.\n dropout3 (nn.Dropout): Dropout layer after feedforward network.\n activation (nn.ReLU): Activation function for the feedforward network.\n pos_enc_at_attn (bool): Flag to add positional encoding at attention.\n pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.\n pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.\n\n Methods:\n forward: Performs the full memory attention operation on input tensors.\n _forward_sa: Performs self-attention on input tensor.\n _forward_ca: Performs cross-attention between target and memory tensors.\n\n Examples:\n >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)\n >>> tgt = torch.randn(1, 100, 256)\n >>> memory = torch.randn(1, 100, 64)\n >>> pos = torch.randn(1, 100, 256)\n >>> query_pos = torch.randn(1, 100, 256)\n >>> output = layer(tgt, memory, pos, query_pos)\n >>> print(output.shape)\n torch.Size([1, 100, 256])\n \"\"\"\n\n def __init__(\n self,\n d_model: int = 256,\n dim_feedforward: int = 2048,\n dropout: float = 0.1,\n pos_enc_at_attn: bool = False,\n pos_enc_at_cross_attn_keys: bool = True,\n pos_enc_at_cross_attn_queries: bool = False,\n ):\n \"\"\"\n Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.\n\n Args:\n d_model (int): Dimensionality of the model.\n dim_feedforward (int): Dimensionality of the feedforward network.\n dropout (float): Dropout rate for regularization.\n pos_enc_at_attn (bool): Whether to add positional encoding at attention.\n pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.\n pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.\n \"\"\"\n super().__init__()\n self.d_model = d_model\n self.dim_feedforward = dim_feedforward\n self.dropout_value = 
dropout\n self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)\n self.cross_attn_image = RoPEAttention(\n rope_k_repeat=True,\n embedding_dim=256,\n num_heads=1,\n downsample_rate=1,\n kv_in_dim=64,\n )\n\n # Implementation of Feedforward model\n self.linear1 = nn.Linear(d_model, dim_feedforward)\n self.dropout = nn.Dropout(dropout)\n self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n self.norm1 = nn.LayerNorm(d_model)\n self.norm2 = nn.LayerNorm(d_model)\n self.norm3 = nn.LayerNorm(d_model)\n self.dropout1 = nn.Dropout(dropout)\n self.dropout2 = nn.Dropout(dropout)\n self.dropout3 = nn.Dropout(dropout)\n\n self.activation = nn.ReLU()\n\n # Where to add pos enc\n self.pos_enc_at_attn = pos_enc_at_attn\n self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries\n self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys\n\n def _forward_sa(self, tgt: torch.Tensor, query_pos: Optional[torch.Tensor]) -> torch.Tensor:\n \"\"\"Perform self-attention on input tensor using positional encoding and RoPE attention mechanism.\"\"\"\n tgt2 = self.norm1(tgt)\n q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2\n tgt2 = self.self_attn(q, k, v=tgt2)\n tgt = tgt + self.dropout1(tgt2)\n return tgt\n\n def _forward_ca(\n self,\n tgt: torch.Tensor,\n memory: torch.Tensor,\n query_pos: Optional[torch.Tensor],\n pos: Optional[torch.Tensor],\n num_k_exclude_rope: int = 0,\n ) -> torch.Tensor:\n \"\"\"Perform cross-attention between target and memory tensors using RoPEAttention mechanism.\"\"\"\n kwds = {}\n if num_k_exclude_rope > 0:\n assert isinstance(self.cross_attn_image, RoPEAttention)\n kwds = {\"num_k_exclude_rope\": num_k_exclude_rope}\n\n # Cross-Attention\n tgt2 = self.norm2(tgt)\n tgt2 = self.cross_attn_image(\n q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,\n k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,\n v=memory,\n **kwds,\n )\n tgt = tgt + self.dropout2(tgt2)\n return tgt\n\n def forward(\n self,\n tgt: torch.Tensor,\n memory: torch.Tensor,\n pos: Optional[torch.Tensor] = None,\n query_pos: Optional[torch.Tensor] = None,\n num_k_exclude_rope: int = 0,\n ) -> torch.Tensor:\n \"\"\"\n Process input tensors through self-attention, cross-attention, and feedforward network layers.\n\n Args:\n tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).\n memory (torch.Tensor): Memory tensor for cross-attention with shape (N, S, D).\n pos (Optional[torch.Tensor]): Positional encoding for memory tensor.\n query_pos (Optional[torch.Tensor]): Positional encoding for target tensor.\n num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.\n\n Returns:\n (torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).\n \"\"\"\n tgt = self._forward_sa(tgt, query_pos)\n tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)\n # MLP\n tgt2 = self.norm3(tgt)\n tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n tgt = tgt + self.dropout3(tgt2)\n return tgt", "chunk_type": "class", "name": "MemoryAttentionLayer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 12, "end_line": 166, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.\n\nThis class combines self-attention, cross-attention, and feedforward components to process input tensors 
and\ngenerate memory-based attention outputs.\n\nAttributes:\n d_model (int): Dimensionality of the model.\n dim_feedforward (int): Dimensionality of the feedforward network.\n dropout_value (float): Dropout rate for regularization.\n self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).\n cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.\n linear1 (nn.Linear): First linear layer of the feedforward network.\n linear2 (nn.Linear): Second linear layer of the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization for self-attention output.\n norm2 (nn.LayerNorm): Layer normalization for cross-attention output.\n norm3 (nn.LayerNorm): Layer normalization for feedforward network output.\n dropout1 (nn.Dropout): Dropout layer after self-attention.\n dropout2 (nn.Dropout): Dropout layer after cross-attention.\n dropout3 (nn.Dropout): Dropout layer after feedforward network.\n activation (nn.ReLU): Activation function for the feedforward network.\n pos_enc_at_attn (bool): Flag to add positional encoding at attention.\n pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.\n pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.\n\nMethods:\n forward: Performs the full memory attention operation on input tensors.\n _forward_sa: Performs self-attention on input tensor.\n _forward_ca: Performs cross-attention between target and memory tensors.\n\nExamples:\n >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)\n >>> tgt = torch.randn(1, 100, 256)\n >>> memory = torch.randn(1, 100, 64)\n >>> pos = torch.randn(1, 100, 256)\n >>> query_pos = torch.randn(1, 100, 256)\n >>> output = layer(tgt, memory, pos, query_pos)\n >>> print(output.shape)\n torch.Size([1, 100, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "typing.Optional", "torch", "torch.nn", "blocks.RoPEAttention", "nn.Module" ], "chunk_id": "class_MemoryAttentionLayer_9da47992" }, { "content": "class MemoryAttention(nn.Module):\n \"\"\"\n Memory attention module for processing sequential data with self and cross-attention mechanisms.\n\n This class implements a multi-layer attention mechanism that combines self-attention and cross-attention\n for processing sequential data, particularly useful in transformer-like architectures.\n\n Attributes:\n d_model (int): The dimension of the model's hidden state.\n layers (nn.ModuleList): A list of MemoryAttentionLayer modules.\n num_layers (int): The number of attention layers.\n norm (nn.LayerNorm): Layer normalization applied to the output.\n pos_enc_at_input (bool): Whether to apply positional encoding at the input.\n batch_first (bool): Whether the input tensors are in batch-first format.\n\n Methods:\n forward: Processes input tensors through the attention layers.\n\n Examples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])\n \"\"\"\n\n def __init__(\n self,\n d_model: int,\n 
pos_enc_at_input: bool,\n layer: nn.Module,\n num_layers: int,\n batch_first: bool = True, # Do layers expect batch first input?\n ):\n \"\"\"\n Initialize MemoryAttention with specified layers and normalization for sequential data processing.\n\n This class implements a multi-layer attention mechanism that combines self-attention and cross-attention\n for processing sequential data, particularly useful in transformer-like architectures.\n\n Args:\n d_model (int): The dimension of the model's hidden state.\n pos_enc_at_input (bool): Whether to apply positional encoding at the input.\n layer (nn.Module): The attention layer to be used in the module.\n num_layers (int): The number of attention layers.\n batch_first (bool): Whether the input tensors are in batch-first format.\n\n Examples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])\n \"\"\"\n super().__init__()\n self.d_model = d_model\n self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])\n self.num_layers = num_layers\n self.norm = nn.LayerNorm(d_model)\n self.pos_enc_at_input = pos_enc_at_input\n self.batch_first = batch_first\n\n def forward(\n self,\n curr: torch.Tensor, # self-attention inputs\n memory: torch.Tensor, # cross-attention inputs\n curr_pos: Optional[torch.Tensor] = None, # pos_enc for self-attention inputs\n memory_pos: Optional[torch.Tensor] = None, # pos_enc for cross-attention inputs\n num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*\n ) -> torch.Tensor:\n \"\"\"\n Process inputs through attention layers, applying self and cross-attention with positional encoding.\n\n Args:\n curr (torch.Tensor): Self-attention input tensor, representing the current state.\n memory (torch.Tensor): Cross-attention input tensor, representing memory information.\n curr_pos (Optional[torch.Tensor]): Positional encoding for self-attention inputs.\n memory_pos (Optional[torch.Tensor]): Positional encoding for cross-attention inputs.\n num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.\n\n Returns:\n (torch.Tensor): Processed output tensor after applying attention layers and normalization.\n\n Examples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])\n \"\"\"\n if isinstance(curr, list):\n assert isinstance(curr_pos, list)\n assert len(curr) == len(curr_pos) == 1\n curr, curr_pos = curr[0], curr_pos[0]\n\n assert curr.shape[1] == memory.shape[1], \"Batch size must be the same for curr and memory\"\n\n output = curr\n if self.pos_enc_at_input and curr_pos is not None:\n output = output + 0.1 * curr_pos\n\n if self.batch_first:\n 
# Convert to batch first\n output = output.transpose(0, 1)\n curr_pos = curr_pos.transpose(0, 1)\n memory = memory.transpose(0, 1)\n memory_pos = memory_pos.transpose(0, 1)\n\n for layer in self.layers:\n kwds = {}\n if isinstance(layer.cross_attn_image, RoPEAttention):\n kwds = {\"num_k_exclude_rope\": num_obj_ptr_tokens}\n\n output = layer(\n tgt=output,\n memory=memory,\n pos=memory_pos,\n query_pos=curr_pos,\n **kwds,\n )\n normed_output = self.norm(output)\n\n if self.batch_first:\n # Convert back to seq first\n normed_output = normed_output.transpose(0, 1)\n curr_pos = curr_pos.transpose(0, 1)\n\n return normed_output", "chunk_type": "class", "name": "MemoryAttention", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py", "start_line": 169, "end_line": 311, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Memory attention module for processing sequential data with self and cross-attention mechanisms.\n\nThis class implements a multi-layer attention mechanism that combines self-attention and cross-attention\nfor processing sequential data, particularly useful in transformer-like architectures.\n\nAttributes:\n d_model (int): The dimension of the model's hidden state.\n layers (nn.ModuleList): A list of MemoryAttentionLayer modules.\n num_layers (int): The number of attention layers.\n norm (nn.LayerNorm): Layer normalization applied to the output.\n pos_enc_at_input (bool): Whether to apply positional encoding at the input.\n batch_first (bool): Whether the input tensors are in batch-first format.\n\nMethods:\n forward: Processes input tensors through the attention layers.\n\nExamples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "typing.Optional", "torch", "torch.nn", "blocks.RoPEAttention", "nn.Module" ], "chunk_id": "class_MemoryAttention_82ab8fdf" }, { "content": "from typing import List", "chunk_type": "import", "name": "List", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List_d360fbfb" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_b9f1e54d" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": 
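A hedged sketch wiring MemoryAttentionLayer into MemoryAttention as shown above. Inputs are sequence-first and transposed internally when batch_first=True; the memory tokens and their positional encoding use the 64-channel dimension expected by the layer's cross-attention (kv_in_dim=64), and the 64x64 token counts keep the RoPE grid square. All shapes and the 4-layer depth are illustrative assumptions.

```python
# Hedged sketch: a MemoryAttention stack built from MemoryAttentionLayer.
import torch
from ultralytics.models.sam.modules.memory_attention import MemoryAttention, MemoryAttentionLayer

d_model = 256
attention = MemoryAttention(
    d_model=d_model,
    pos_enc_at_input=True,
    layer=MemoryAttentionLayer(d_model=d_model),
    num_layers=4,
)
curr = torch.randn(4096, 1, d_model)  # (seq_len, batch, d_model): 64x64 current-frame tokens
curr_pos = torch.randn(4096, 1, d_model)
memory = torch.randn(4096, 1, 64)     # (mem_len, batch, mem_dim): memory keys/values
memory_pos = torch.randn(4096, 1, 64)

out = attention(curr, memory, curr_pos=curr_pos, memory_pos=memory_pos)
print(out.shape)                      # torch.Size([4096, 1, 256])
```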
null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_5af2f3e6" }, { "content": "from torch import nn", "chunk_type": "import", "name": "nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_nn_fa6e5c75" }, { "content": "from torch.nn.init import trunc_normal_", "chunk_type": "import", "name": "trunc_normal_", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_trunc_normal__e8ba48f7" }, { "content": "from ultralytics.nn.modules import MLP", "chunk_type": "import", "name": "MLP", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MLP_7c9a62a8" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_fde5fd72" }, { "content": "from .blocks import SAM2TwoWayTransformer", "chunk_type": "import", "name": "SAM2TwoWayTransformer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SAM2TwoWayTransformer_ee3bfbb6" }, { "content": "from .decoders import MaskDecoder, SAM2MaskDecoder", "chunk_type": "import", "name": "MaskDecoder, SAM2MaskDecoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MaskDecoder, SAM2MaskDecoder_753b2615" }, { "content": "from .encoders import ImageEncoderViT, PromptEncoder", "chunk_type": "import", "name": "ImageEncoderViT, PromptEncoder", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ImageEncoderViT, PromptEncoder_47cfddc7" }, { "content": "from .utils import get_1d_sine_pe, select_closest_cond_frames", "chunk_type": "import", "name": "get_1d_sine_pe, select_closest_cond_frames", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 22, "end_line": 22, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, 
"return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_get_1d_sine_pe, select_closest_cond_frames_b5ce41f7" }, { "content": "NO_OBJ_SCORE = -1024.0", "chunk_type": "variable", "name": "NO_OBJ_SCORE", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 25, "end_line": 25, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_NO_OBJ_SCORE_03a1cbc0" }, { "content": "class SAMModel(nn.Module):\n \"\"\"\n Segment Anything Model (SAM) for object segmentation tasks.\n\n This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images\n and input prompts.\n\n Attributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.\n prompt_encoder (PromptEncoder): Encoder for various types of input prompts.\n mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.\n pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.\n pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.\n\n Methods:\n set_imgsz: Set image size to make model compatible with different image sizes.\n\n Examples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> prompt_encoder = PromptEncoder(...)\n >>> mask_decoder = MaskDecoder(...)\n >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)\n >>> # Further usage depends on SAMPredictor class\n\n Notes:\n All forward() operations are implemented in the SAMPredictor class.\n \"\"\"\n\n mask_threshold: float = 0.0\n\n def __init__(\n self,\n image_encoder: ImageEncoderViT,\n prompt_encoder: PromptEncoder,\n mask_decoder: MaskDecoder,\n pixel_mean: List[float] = (123.675, 116.28, 103.53),\n pixel_std: List[float] = (58.395, 57.12, 57.375),\n ) -> None:\n \"\"\"\n Initialize the SAMModel class to predict object masks from an image and input prompts.\n\n Args:\n image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.\n prompt_encoder (PromptEncoder): Encodes various types of input prompts.\n mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.\n pixel_mean (List[float]): Mean values for normalizing pixels in the input image.\n pixel_std (List[float]): Standard deviation values for normalizing pixels in the input image.\n\n Examples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> prompt_encoder = PromptEncoder(...)\n >>> mask_decoder = MaskDecoder(...)\n >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)\n >>> # Further usage depends on SAMPredictor class\n\n Notes:\n All forward() operations moved to SAMPredictor.\n \"\"\"\n super().__init__()\n self.image_encoder = image_encoder\n self.prompt_encoder = prompt_encoder\n self.mask_decoder = mask_decoder\n self.register_buffer(\"pixel_mean\", torch.Tensor(pixel_mean).view(-1, 1, 1), False)\n self.register_buffer(\"pixel_std\", torch.Tensor(pixel_std).view(-1, 1, 1), False)\n\n def set_imgsz(self, imgsz):\n \"\"\"Set image size to make model compatible with different image sizes.\"\"\"\n if hasattr(self.image_encoder, \"set_imgsz\"):\n self.image_encoder.set_imgsz(imgsz)\n self.prompt_encoder.input_image_size = imgsz\n self.prompt_encoder.image_embedding_size = [x // 16 for x in 
imgsz] # 16 is fixed as patch size of ViT model\n self.image_encoder.img_size = imgsz[0]", "chunk_type": "class", "name": "SAMModel", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 28, "end_line": 100, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Segment Anything Model (SAM) for object segmentation tasks.\n\nThis class combines image encoders, prompt encoders, and mask decoders to predict object masks from images\nand input prompts.\n\nAttributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.\n prompt_encoder (PromptEncoder): Encoder for various types of input prompts.\n mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.\n pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.\n pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.\n\nMethods:\n set_imgsz: Set image size to make model compatible with different image sizes.\n\nExamples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> prompt_encoder = PromptEncoder(...)\n >>> mask_decoder = MaskDecoder(...)\n >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)\n >>> # Further usage depends on SAMPredictor class\n\nNotes:\n All forward() operations are implemented in the SAMPredictor class.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "torch", "torch.nn.functional", "torch.nn", "torch.nn.init.trunc_normal_", "ultralytics.nn.modules.MLP", "ultralytics.utils.LOGGER", "blocks.SAM2TwoWayTransformer", "decoders.MaskDecoder", "decoders.SAM2MaskDecoder", "encoders.ImageEncoderViT", "encoders.PromptEncoder", "utils.get_1d_sine_pe", "utils.select_closest_cond_frames", "nn.Module" ], "chunk_id": "class_SAMModel_74592926" }, { "content": "class SAM2Model(torch.nn.Module):\n \"\"\"\n SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.\n\n This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms\n for temporal consistency and efficient tracking of objects across frames.\n\n Attributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Visual encoder for extracting image features.\n memory_attention (nn.Module): Module for attending to memory features.\n memory_encoder (nn.Module): Encoder for generating memory representations.\n num_maskmem (int): Number of accessible memory frames.\n image_size (int): Size of input images.\n backbone_stride (int): Stride of the backbone network output.\n sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.\n sam_image_embedding_size (int): Size of SAM image embeddings.\n sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.\n sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.\n obj_ptr_proj (nn.Module): Projection layer for object pointers.\n obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.\n hidden_dim (int): Hidden dimension of the model.\n mem_dim (int): Memory dimension for encoding features.\n use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.\n use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.\n max_obj_ptrs_in_encoder 
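As a quick illustration of the set_imgsz() arithmetic above: the prompt encoder's embedding grid is simply the requested image size divided by the fixed ViT patch stride of 16 (mask prediction itself lives in SAMPredictor, per the Notes).

```python
# The prompt-encoder embedding grid tracks the image size at the fixed ViT patch stride of 16.
imgsz = (1024, 1024)
image_embedding_size = [x // 16 for x in imgsz]
print(image_embedding_size)  # [64, 64] -> a 64x64 grid of image-embedding positions
```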
(int): Maximum number of object pointers from other frames in encoder cross-attention.\n add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.\n proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional\n encoding in object pointers.\n use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.\n only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during\n evaluation.\n pred_obj_scores (bool): Whether to predict if there is an object in the frame.\n pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.\n fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.\n soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.\n use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.\n no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.\n max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.\n directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the\n first frame.\n multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial\n conditioning frames.\n multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.\n multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.\n multimask_output_for_tracking (bool): Whether to use multimask output for tracking.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].\n memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.\n non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in\n memory encoder during evaluation.\n sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.\n sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.\n binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames\n with clicks during evaluation.\n use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM\n prompt encoder and mask decoder on frames with mask input.\n\n Methods:\n forward_image: Process image batch through encoder to extract multi-level features.\n track_step: Perform a single tracking step, updating object masks and memory features.\n set_binarize: Set binarize for VideoPredictor.\n set_imgsz: Set image size to make model compatible with different image sizes.\n\n Examples:\n >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)\n >>> image_batch = torch.rand(1, 3, 512, 512)\n >>> features = model.forward_image(image_batch)\n >>> track_results = model.track_step(0, True, features, None, None, None, {})\n \"\"\"\n\n mask_threshold: float = 0.0\n\n def __init__(\n self,\n image_encoder,\n memory_attention,\n memory_encoder,\n num_maskmem=7,\n image_size=512,\n backbone_stride=16,\n sigmoid_scale_for_mem_enc=1.0,\n sigmoid_bias_for_mem_enc=0.0,\n binarize_mask_from_pts_for_mem_enc=False,\n use_mask_input_as_output_without_sam=False,\n max_cond_frames_in_attn=-1,\n 
directly_add_no_mem_embed=False,\n use_high_res_features_in_sam=False,\n multimask_output_in_sam=False,\n multimask_min_pt_num=1,\n multimask_max_pt_num=1,\n multimask_output_for_tracking=False,\n use_multimask_token_for_obj_ptr: bool = False,\n iou_prediction_use_sigmoid=False,\n memory_temporal_stride_for_eval=1,\n non_overlap_masks_for_mem_enc=False,\n use_obj_ptrs_in_encoder=False,\n max_obj_ptrs_in_encoder=16,\n add_tpos_enc_to_obj_ptrs=True,\n proj_tpos_enc_in_obj_ptrs=False,\n use_signed_tpos_enc_to_obj_ptrs=False,\n only_obj_ptrs_in_the_past_for_eval=False,\n pred_obj_scores: bool = False,\n pred_obj_scores_mlp: bool = False,\n fixed_no_obj_ptr: bool = False,\n soft_no_obj_ptr: bool = False,\n use_mlp_for_obj_ptr_proj: bool = False,\n no_obj_embed_spatial: bool = False,\n sam_mask_decoder_extra_args=None,\n compile_image_encoder: bool = False,\n ):\n \"\"\"\n Initialize the SAM2Model for video object segmentation with memory-based tracking.\n\n Args:\n image_encoder (nn.Module): Visual encoder for extracting image features.\n memory_attention (nn.Module): Module for attending to memory features.\n memory_encoder (nn.Module): Encoder for generating memory representations.\n num_maskmem (int): Number of accessible memory frames.\n image_size (int): Size of input images.\n backbone_stride (int): Stride of the image backbone output.\n sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.\n sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.\n binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames\n with clicks during evaluation.\n use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM\n prompt encoder and mask decoder on frames with mask input.\n max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.\n directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the\n first frame.\n use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.\n multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial\n conditioning frames.\n multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.\n multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.\n multimask_output_for_tracking (bool): Whether to use multimask output for tracking.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].\n memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.\n non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in\n memory encoder during evaluation.\n use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.\n max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder\n cross-attention.\n add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in\n the encoder.\n proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional\n encoding in object pointers.\n use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding\n 
in the object pointers.\n only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past\n during evaluation.\n pred_obj_scores (bool): Whether to predict if there is an object in the frame.\n pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.\n fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.\n soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.\n use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.\n no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.\n sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.\n compile_image_encoder (bool): Whether to compile the image encoder for faster inference.\n\n Examples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> memory_attention = SAM2TwoWayTransformer(...)\n >>> memory_encoder = nn.Sequential(...)\n >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)\n >>> image_batch = torch.rand(1, 3, 512, 512)\n >>> features = model.forward_image(image_batch)\n >>> track_results = model.track_step(0, True, features, None, None, None, {})\n \"\"\"\n super().__init__()\n\n # Part 1: the image backbone\n self.image_encoder = image_encoder\n # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting\n self.use_high_res_features_in_sam = use_high_res_features_in_sam\n self.num_feature_levels = 3 if use_high_res_features_in_sam else 1\n self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder\n self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder\n if use_obj_ptrs_in_encoder:\n # A conv layer to downsample the mask prompt to stride 4 (the same stride as\n # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,\n # so that it can be fed into the SAM mask decoder to generate a pointer.\n self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)\n self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs\n if proj_tpos_enc_in_obj_ptrs:\n assert add_tpos_enc_to_obj_ptrs # these options need to be used together\n self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs\n self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs\n self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval\n\n # Part 2: memory attention to condition current frame's visual features\n # with memories (and obj ptrs) from past frames\n self.memory_attention = memory_attention\n self.hidden_dim = memory_attention.d_model\n\n # Part 3: memory encoder for the previous frame's outputs\n self.memory_encoder = memory_encoder\n self.mem_dim = self.hidden_dim\n if hasattr(self.memory_encoder, \"out_proj\") and hasattr(self.memory_encoder.out_proj, \"weight\"):\n # if there is compression of memories along channel dim\n self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]\n self.num_maskmem = num_maskmem # Number of memories accessible\n # Temporal encoding of the memories\n self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))\n trunc_normal_(self.maskmem_tpos_enc, std=0.02)\n # a single token to indicate no memory embedding from previous frames\n self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))\n self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))\n trunc_normal_(self.no_mem_embed, std=0.02)\n trunc_normal_(self.no_mem_pos_enc, 
std=0.02)\n self.directly_add_no_mem_embed = directly_add_no_mem_embed\n # Apply sigmoid to the output raw mask logits (to turn them from\n # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder\n self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc\n self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc\n self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc\n self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc\n self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval\n # On frames with mask input, whether to directly output the input mask without\n # using a SAM prompt encoder + mask decoder\n self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam\n self.multimask_output_in_sam = multimask_output_in_sam\n self.multimask_min_pt_num = multimask_min_pt_num\n self.multimask_max_pt_num = multimask_max_pt_num\n self.multimask_output_for_tracking = multimask_output_for_tracking\n self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr\n self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid\n\n # Part 4: SAM-style prompt encoder (for both mask and point inputs)\n # and SAM-style mask decoder for the final mask output\n self.image_size = image_size\n self.backbone_stride = backbone_stride\n self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args\n self.pred_obj_scores = pred_obj_scores\n self.pred_obj_scores_mlp = pred_obj_scores_mlp\n self.fixed_no_obj_ptr = fixed_no_obj_ptr\n self.soft_no_obj_ptr = soft_no_obj_ptr\n if self.fixed_no_obj_ptr:\n assert self.pred_obj_scores\n assert self.use_obj_ptrs_in_encoder\n if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:\n self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))\n trunc_normal_(self.no_obj_ptr, std=0.02)\n self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj\n self.no_obj_embed_spatial = None\n if no_obj_embed_spatial:\n self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))\n trunc_normal_(self.no_obj_embed_spatial, std=0.02)\n\n self._build_sam_heads()\n self.max_cond_frames_in_attn = max_cond_frames_in_attn\n\n # Model compilation\n if compile_image_encoder:\n # Compile the forward function (not the full module) to allow loading checkpoints.\n LOGGER.info(\"Image encoder compilation is enabled. 
First forward pass will be slow.\")\n self.image_encoder.forward = torch.compile(\n self.image_encoder.forward,\n mode=\"max-autotune\",\n fullgraph=True,\n dynamic=False,\n )\n\n @property\n def device(self):\n \"\"\"Return the device on which the model's parameters are stored.\"\"\"\n return next(self.parameters()).device\n\n def forward(self, *args, **kwargs):\n \"\"\"Process image and prompt inputs to generate object masks and scores in video sequences.\"\"\"\n raise NotImplementedError(\n \"Please use the corresponding methods in SAM2VideoPredictor for inference.\"\n \"See notebooks/video_predictor_example.ipynb for an example.\"\n )\n\n def _build_sam_heads(self):\n \"\"\"Build SAM-style prompt encoder and mask decoder for image segmentation tasks.\"\"\"\n self.sam_prompt_embed_dim = self.hidden_dim\n self.sam_image_embedding_size = self.image_size // self.backbone_stride\n\n # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)\n self.sam_prompt_encoder = PromptEncoder(\n embed_dim=self.sam_prompt_embed_dim,\n image_embedding_size=(\n self.sam_image_embedding_size,\n self.sam_image_embedding_size,\n ),\n input_image_size=(self.image_size, self.image_size),\n mask_in_chans=16,\n )\n self.sam_mask_decoder = SAM2MaskDecoder(\n num_multimask_outputs=3,\n transformer=SAM2TwoWayTransformer(\n depth=2,\n embedding_dim=self.sam_prompt_embed_dim,\n mlp_dim=2048,\n num_heads=8,\n ),\n transformer_dim=self.sam_prompt_embed_dim,\n iou_head_depth=3,\n iou_head_hidden_dim=256,\n use_high_res_features=self.use_high_res_features_in_sam,\n iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,\n pred_obj_scores=self.pred_obj_scores,\n pred_obj_scores_mlp=self.pred_obj_scores_mlp,\n use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,\n **(self.sam_mask_decoder_extra_args or {}),\n )\n if self.use_obj_ptrs_in_encoder:\n # a linear projection on SAM output tokens to turn them into object pointers\n self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)\n if self.use_mlp_for_obj_ptr_proj:\n self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)\n else:\n self.obj_ptr_proj = torch.nn.Identity()\n if self.proj_tpos_enc_in_obj_ptrs:\n # a linear projection on temporal positional encoding in object pointers to\n # avoid potential interference with spatial positional encoding\n self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)\n else:\n self.obj_ptr_tpos_proj = torch.nn.Identity()\n\n def _forward_sam_heads(\n self,\n backbone_features,\n point_inputs=None,\n mask_inputs=None,\n high_res_features=None,\n multimask_output=False,\n ):\n \"\"\"\n Forward pass through SAM prompt encoders and mask heads.\n\n This method processes image features and optional point/mask inputs to generate object masks and scores.\n\n Args:\n backbone_features (torch.Tensor): Image features with shape (B, C, H, W).\n point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.\n 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute\n pixel-unit coordinates in (x, y) format for P input points.\n 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,\n 0 means negative clicks, and -1 means padding.\n mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the\n same spatial size as the image.\n high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes\n (B, C, 
4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps\n for SAM decoder.\n multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,\n output only 1 mask and its IoU estimate.\n\n Returns:\n low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.\n high_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.\n ious (torch.Tensor): Tensor of shape (B, M) with estimated IoU for each output mask.\n low_res_masks (torch.Tensor): Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.\n high_res_masks (torch.Tensor): Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.\n obj_ptr (torch.Tensor): Tensor of shape (B, C) with object pointer vector for the output mask.\n object_score_logits (torch.Tensor): Tensor of shape (B) with object score logits.\n\n Examples:\n >>> backbone_features = torch.rand(1, 256, 32, 32)\n >>> point_inputs = {\"point_coords\": torch.rand(1, 2, 2), \"point_labels\": torch.tensor([[1, 0]])}\n >>> mask_inputs = torch.rand(1, 1, 512, 512)\n >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)\n >>> (\n ... low_res_multimasks,\n ... high_res_multimasks,\n ... ious,\n ... low_res_masks,\n ... high_res_masks,\n ... obj_ptr,\n ... object_score_logits,\n ... ) = results\n \"\"\"\n B = backbone_features.size(0)\n device = backbone_features.device\n assert backbone_features.size(1) == self.sam_prompt_embed_dim\n assert backbone_features.size(2) == self.sam_image_embedding_size\n assert backbone_features.size(3) == self.sam_image_embedding_size\n\n # a) Handle point prompts\n if point_inputs is not None:\n sam_point_coords = point_inputs[\"point_coords\"]\n sam_point_labels = point_inputs[\"point_labels\"]\n assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B\n else:\n # If no points are provided, pad with an empty point (with label -1)\n sam_point_coords = torch.zeros(B, 1, 2, device=device)\n sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)\n\n # b) Handle mask prompts\n if mask_inputs is not None:\n # If mask_inputs is provided, downsize it into low-res mask input if needed\n # and feed it as a dense mask prompt into the SAM mask encoder\n assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)\n if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:\n sam_mask_prompt = F.interpolate(\n mask_inputs.float(),\n size=self.sam_prompt_encoder.mask_input_size,\n align_corners=False,\n mode=\"bilinear\",\n antialias=True, # use antialias for downsampling\n )\n else:\n sam_mask_prompt = mask_inputs\n else:\n # Otherwise, simply feed None (and SAM's prompt encoder will add\n # a learned `no_mask_embed` to indicate no mask input in this case).\n sam_mask_prompt = None\n\n sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(\n points=(sam_point_coords, sam_point_labels),\n boxes=None,\n masks=sam_mask_prompt,\n )\n low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(\n image_embeddings=backbone_features,\n image_pe=self.sam_prompt_encoder.get_dense_pe(),\n sparse_prompt_embeddings=sparse_embeddings,\n dense_prompt_embeddings=dense_embeddings,\n multimask_output=multimask_output,\n repeat_image=False, # the image is already batched\n high_res_features=high_res_features,\n )\n if self.pred_obj_scores:\n is_obj_appearing = object_score_logits > 0\n\n # Spatial memory mask is a 
*hard* choice between obj and no obj, consistent with actual mask prediction\n low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)\n\n # convert masks from possibly bfloat16 (or float16) to float32\n # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)\n low_res_multimasks = low_res_multimasks.float()\n high_res_multimasks = F.interpolate(\n low_res_multimasks,\n size=(self.image_size, self.image_size),\n mode=\"bilinear\",\n align_corners=False,\n )\n\n sam_output_token = sam_output_tokens[:, 0]\n if multimask_output:\n # take the best mask prediction (with the highest IoU estimation)\n best_iou_inds = torch.argmax(ious, dim=-1)\n batch_inds = torch.arange(B, device=device)\n low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)\n high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)\n if sam_output_tokens.size(1) > 1:\n sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]\n else:\n low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks\n\n # Extract object pointer from the SAM output token (with occlusion handling)\n obj_ptr = self.obj_ptr_proj(sam_output_token)\n if self.pred_obj_scores:\n # Allow *soft* no obj ptr, unlike for masks\n if self.soft_no_obj_ptr:\n lambda_is_obj_appearing = object_score_logits.sigmoid()\n else:\n lambda_is_obj_appearing = is_obj_appearing.float()\n\n if self.fixed_no_obj_ptr:\n obj_ptr = lambda_is_obj_appearing * obj_ptr\n obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr\n\n return (\n low_res_multimasks,\n high_res_multimasks,\n ious,\n low_res_masks,\n high_res_masks,\n obj_ptr,\n object_score_logits,\n )\n\n def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):\n \"\"\"Process mask inputs directly as output, bypassing SAM encoder/decoder.\"\"\"\n # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).\n out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05\n mask_inputs_float = mask_inputs.float()\n high_res_masks = mask_inputs_float * out_scale + out_bias\n low_res_masks = F.interpolate(\n high_res_masks,\n size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),\n align_corners=False,\n mode=\"bilinear\",\n antialias=True, # use antialias for downsampling\n )\n # a dummy IoU prediction of all 1's under mask input\n ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()\n if not self.use_obj_ptrs_in_encoder:\n # all zeros as a dummy object pointer (of shape [B, C])\n obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)\n else:\n # produce an object pointer using the SAM decoder from the mask input\n _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(\n backbone_features=backbone_features,\n mask_inputs=self.mask_downsample(mask_inputs_float),\n high_res_features=high_res_features,\n )\n # In this method, we are treating mask_input as output, e.g. 
using it directly to create spatial mem;\n # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying\n # on the object_scores from the SAM decoder.\n is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)\n is_obj_appearing = is_obj_appearing[..., None]\n lambda_is_obj_appearing = is_obj_appearing.float()\n object_score_logits = out_scale * lambda_is_obj_appearing + out_bias\n if self.pred_obj_scores:\n if self.fixed_no_obj_ptr:\n obj_ptr = lambda_is_obj_appearing * obj_ptr\n obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr\n\n return (\n low_res_masks,\n high_res_masks,\n ious,\n low_res_masks,\n high_res_masks,\n obj_ptr,\n object_score_logits,\n )\n\n def forward_image(self, img_batch: torch.Tensor):\n \"\"\"Process image batch through encoder to extract multi-level features for SAM model.\"\"\"\n backbone_out = self.image_encoder(img_batch)\n if self.use_high_res_features_in_sam:\n # precompute projected level 0 and level 1 features in SAM decoder\n # to avoid running it again on every SAM click\n backbone_out[\"backbone_fpn\"][0] = self.sam_mask_decoder.conv_s0(backbone_out[\"backbone_fpn\"][0])\n backbone_out[\"backbone_fpn\"][1] = self.sam_mask_decoder.conv_s1(backbone_out[\"backbone_fpn\"][1])\n return backbone_out\n\n def _prepare_backbone_features(self, backbone_out):\n \"\"\"Prepare and flatten visual features from the image backbone output for further processing.\"\"\"\n assert len(backbone_out[\"backbone_fpn\"]) == len(backbone_out[\"vision_pos_enc\"])\n assert len(backbone_out[\"backbone_fpn\"]) >= self.num_feature_levels\n\n feature_maps = backbone_out[\"backbone_fpn\"][-self.num_feature_levels :]\n vision_pos_embeds = backbone_out[\"vision_pos_enc\"][-self.num_feature_levels :]\n\n feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]\n # flatten NxCxHxW to HWxNxC\n vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]\n vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]\n\n return backbone_out, vision_feats, vision_pos_embeds, feat_sizes\n\n def _prepare_memory_conditioned_features(\n self,\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n output_dict,\n num_frames,\n track_in_reverse=False, # tracking in reverse time order (for demo usage)\n ):\n \"\"\"Prepare memory-conditioned features by fusing current frame's visual features with previous memories.\"\"\"\n B = current_vision_feats[-1].size(1) # batch size on this frame\n C = self.hidden_dim\n H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size\n device = current_vision_feats[-1].device\n # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.\n # In this case, we skip the fusion with any memory.\n if self.num_maskmem == 0: # Disable memory and skip fusion\n return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)\n num_obj_ptr_tokens = 0\n tpos_sign_mul = -1 if track_in_reverse else 1\n # Step 1: condition the visual features of the current frame on previous memories\n if not is_init_cond_frame:\n # Retrieve the memories encoded with the maskmem backbone\n to_cat_memory, to_cat_memory_pos_embed = [], []\n # Add conditioning frame's output first (all cond frames have t_pos=0 for\n # when getting temporal positional embedding below)\n assert len(output_dict[\"cond_frame_outputs\"]) > 0\n # Select a maximum number of temporally closest cond frames for cross 
attention\n cond_outputs = output_dict[\"cond_frame_outputs\"]\n selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(\n frame_idx, cond_outputs, self.max_cond_frames_in_attn\n )\n t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]\n # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory\n # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1\n # We also allow taking the memory frame non-consecutively (with r>1), in which case\n # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.\n r = 1 if self.training else self.memory_temporal_stride_for_eval\n for t_pos in range(1, self.num_maskmem):\n t_rel = self.num_maskmem - t_pos # how many frames before current frame\n if t_rel == 1:\n # for t_rel == 1, we take the last frame (regardless of r)\n prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel\n elif not track_in_reverse:\n # first find the nearest frame among every r-th frames before this frame\n # for r=1, this would be (frame_idx - 2)\n prev_frame_idx = ((frame_idx - 2) // r) * r\n # then seek further among every r-th frames\n prev_frame_idx = prev_frame_idx - (t_rel - 2) * r\n else:\n # first find the nearest frame among every r-th frames after this frame\n # for r=1, this would be (frame_idx + 2)\n prev_frame_idx = -(-(frame_idx + 2) // r) * r\n # then seek further among every r-th frames\n prev_frame_idx = prev_frame_idx + (t_rel - 2) * r\n out = output_dict[\"non_cond_frame_outputs\"].get(prev_frame_idx, None)\n if out is None:\n # If an unselected conditioning frame is among the last (self.num_maskmem - 1)\n # frames, we still attend to it as if it's a non-conditioning frame.\n out = unselected_cond_outputs.get(prev_frame_idx, None)\n t_pos_and_prevs.append((t_pos, out))\n\n for t_pos, prev in t_pos_and_prevs:\n if prev is None:\n continue # skip padding frames\n # \"maskmem_features\" might have been offloaded to CPU in demo use cases,\n # so we load it back to inference device (it's a no-op if it's already on device).\n feats = prev[\"maskmem_features\"].to(device=device, non_blocking=True)\n to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))\n # Spatial positional encoding (it might have been offloaded to CPU in eval)\n maskmem_enc = prev[\"maskmem_pos_enc\"][-1].to(device=device)\n maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)\n # Temporal positional encoding\n maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]\n to_cat_memory_pos_embed.append(maskmem_enc)\n\n # Construct the list of past object pointers\n if self.use_obj_ptrs_in_encoder:\n max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)\n # First add those object pointers from selected conditioning frames\n # (optionally, only include object pointers in the past during evaluation)\n if not self.training and self.only_obj_ptrs_in_the_past_for_eval:\n ptr_cond_outputs = {\n t: out\n for t, out in selected_cond_outputs.items()\n if (t >= frame_idx if track_in_reverse else t <= frame_idx)\n }\n else:\n ptr_cond_outputs = selected_cond_outputs\n pos_and_ptrs = [\n # Temporal pos encoding contains how far away each pointer is from current frame\n (\n (\n (frame_idx - t) * tpos_sign_mul\n if self.use_signed_tpos_enc_to_obj_ptrs\n else abs(frame_idx - t)\n ),\n out[\"obj_ptr\"],\n )\n for t, out in ptr_cond_outputs.items()\n ]\n # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current 
frame\n for t_diff in range(1, max_obj_ptrs_in_encoder):\n t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff\n if t < 0 or (num_frames is not None and t >= num_frames):\n break\n out = output_dict[\"non_cond_frame_outputs\"].get(t, unselected_cond_outputs.get(t, None))\n if out is not None:\n pos_and_ptrs.append((t_diff, out[\"obj_ptr\"]))\n # If we have at least one object pointer, add them to the cross attention\n if pos_and_ptrs:\n pos_list, ptrs_list = zip(*pos_and_ptrs)\n # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape\n obj_ptrs = torch.stack(ptrs_list, dim=0)\n # a temporal positional embedding based on how far each object pointer is from\n # the current frame (sine embedding normalized by the max pointer num).\n if self.add_tpos_enc_to_obj_ptrs:\n t_diff_max = max_obj_ptrs_in_encoder - 1\n tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim\n obj_pos = torch.tensor(pos_list, device=device)\n obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)\n obj_pos = self.obj_ptr_tpos_proj(obj_pos)\n obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)\n else:\n obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)\n if self.mem_dim < C:\n # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C\n obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)\n obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)\n obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)\n to_cat_memory.append(obj_ptrs)\n to_cat_memory_pos_embed.append(obj_pos)\n num_obj_ptr_tokens = obj_ptrs.shape[0]\n else:\n num_obj_ptr_tokens = 0\n else:\n # for initial conditioning frames, encode them without using any previous memory\n if self.directly_add_no_mem_embed:\n # directly add no-mem embedding (instead of using the transformer encoder)\n pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed\n pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)\n return pix_feat_with_mem\n\n # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)\n to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]\n to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]\n\n # Step 2: Concatenate the memories and forward through the transformer encoder\n memory = torch.cat(to_cat_memory, dim=0)\n memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)\n\n pix_feat_with_mem = self.memory_attention(\n curr=current_vision_feats,\n curr_pos=current_vision_pos_embeds,\n memory=memory,\n memory_pos=memory_pos_embed,\n num_obj_ptr_tokens=num_obj_ptr_tokens,\n )\n # reshape the output (HW)BC => BCHW\n pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)\n return pix_feat_with_mem\n\n def _encode_new_memory(\n self,\n current_vision_feats,\n feat_sizes,\n pred_masks_high_res,\n object_score_logits,\n is_mask_from_pts,\n ):\n \"\"\"Encode frame features and masks into a new memory representation for video segmentation.\"\"\"\n B = current_vision_feats[-1].size(1) # batch size on this frame\n C = self.hidden_dim\n H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size\n # top-level feature, (HW)BC => BCHW\n pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)\n if self.non_overlap_masks_for_mem_enc and not self.training:\n # optionally, apply non-overlapping constraints to the masks (it's applied\n # in the batch dimension and should only be used during eval, where all\n # the objects come from the same video 
under batch size 1).\n pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)\n # scale the raw mask logits with a temperature before applying sigmoid\n binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts\n if binarize and not self.training:\n mask_for_mem = (pred_masks_high_res > 0).float()\n else:\n # apply sigmoid on the raw mask logits to turn them into range (0, 1)\n mask_for_mem = torch.sigmoid(pred_masks_high_res)\n # apply scale and bias terms to the sigmoid probabilities\n if self.sigmoid_scale_for_mem_enc != 1.0:\n mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc\n if self.sigmoid_bias_for_mem_enc != 0.0:\n mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc\n maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied\n maskmem_features = maskmem_out[\"vision_features\"]\n maskmem_pos_enc = maskmem_out[\"vision_pos_enc\"]\n # add a no-object embedding to the spatial memory to indicate that the frame\n # is predicted to be occluded (i.e. no object is appearing in the frame)\n if self.no_obj_embed_spatial is not None:\n is_obj_appearing = (object_score_logits > 0).float()\n maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[\n ..., None, None\n ].expand(*maskmem_features.shape)\n\n return maskmem_features, maskmem_pos_enc\n\n def _track_step(\n self,\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n point_inputs,\n mask_inputs,\n output_dict,\n num_frames,\n track_in_reverse,\n prev_sam_mask_logits,\n ):\n \"\"\"Perform a single tracking step, updating object masks and memory features based on current frame inputs.\"\"\"\n current_out = {\"point_inputs\": point_inputs, \"mask_inputs\": mask_inputs}\n # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW\n if len(current_vision_feats) > 1:\n high_res_features = [\n x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)\n for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])\n ]\n else:\n high_res_features = None\n if mask_inputs is not None and self.use_mask_input_as_output_without_sam:\n # When use_mask_input_as_output_without_sam=True, we directly output the mask input\n # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.\n pix_feat = current_vision_feats[-1].permute(1, 2, 0)\n pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])\n sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)\n else:\n # fused the visual feature with previous memory features in the memory bank\n pix_feat = self._prepare_memory_conditioned_features(\n frame_idx=frame_idx,\n is_init_cond_frame=is_init_cond_frame,\n current_vision_feats=current_vision_feats[-1:],\n current_vision_pos_embeds=current_vision_pos_embeds[-1:],\n feat_sizes=feat_sizes[-1:],\n output_dict=output_dict,\n num_frames=num_frames,\n track_in_reverse=track_in_reverse,\n )\n # apply SAM-style segmentation head\n # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,\n # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling\n # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)\n if prev_sam_mask_logits is not None:\n assert point_inputs is not None and mask_inputs is None\n mask_inputs = prev_sam_mask_logits\n multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)\n sam_outputs = self._forward_sam_heads(\n backbone_features=pix_feat,\n point_inputs=point_inputs,\n mask_inputs=mask_inputs,\n high_res_features=high_res_features,\n multimask_output=multimask_output,\n )\n return current_out, sam_outputs, high_res_features, pix_feat\n\n def _encode_memory_in_output(\n self,\n current_vision_feats,\n feat_sizes,\n point_inputs,\n run_mem_encoder,\n high_res_masks,\n object_score_logits,\n current_out,\n ):\n \"\"\"Run memory encoder on predicted mask to encode it into a new memory feature for future frames.\"\"\"\n if run_mem_encoder and self.num_maskmem > 0:\n high_res_masks_for_mem_enc = high_res_masks\n maskmem_features, maskmem_pos_enc = self._encode_new_memory(\n current_vision_feats=current_vision_feats,\n feat_sizes=feat_sizes,\n pred_masks_high_res=high_res_masks_for_mem_enc,\n object_score_logits=object_score_logits,\n is_mask_from_pts=(point_inputs is not None),\n )\n current_out[\"maskmem_features\"] = maskmem_features\n current_out[\"maskmem_pos_enc\"] = maskmem_pos_enc\n else:\n current_out[\"maskmem_features\"] = None\n current_out[\"maskmem_pos_enc\"] = None\n\n def track_step(\n self,\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n point_inputs,\n mask_inputs,\n output_dict,\n num_frames,\n track_in_reverse=False, # tracking in reverse time order (for demo usage)\n # Whether to run the memory encoder on the predicted masks. Sometimes we might want\n # to skip the memory encoder with `run_mem_encoder=False`. For example,\n # in demo we might call `track_step` multiple times for each user click,\n # and only encode the memory when the user finalizes their clicks. 
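# --- Editor's aside (illustrative sketch, not part of the original module) -----------------------
# The stride-r memory selection in `_prepare_memory_conditioned_features` above is easiest to follow
# with concrete numbers. The helper below mirrors that arithmetic for forward tracking only; its name
# and defaults are hypothetical and exist purely for illustration.
def _pick_memory_frames(frame_idx: int, num_maskmem: int = 7, r: int = 1) -> list:
    """Return (t_pos, prev_frame_idx) pairs matching the non-conditioning memory selection above."""
    chosen = []
    for t_pos in range(1, num_maskmem):
        t_rel = num_maskmem - t_pos  # how many frames before the current frame
        if t_rel == 1:
            prev_frame_idx = frame_idx - t_rel  # always the immediately previous frame, regardless of r
        else:
            # nearest multiple of r at or below (frame_idx - 2), then step back in r-sized hops
            prev_frame_idx = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
        chosen.append((t_pos, prev_frame_idx))
    # Indices that have no entry in the memory bank (e.g. negative ones) are skipped as padding above.
    return chosen

# _pick_memory_frames(10, r=1) -> [(1, 4), (2, 5), (3, 6), (4, 7), (5, 8), (6, 9)]
# _pick_memory_frames(10, r=2) -> [(1, 0), (2, 2), (3, 4), (4, 6), (5, 8), (6, 9)]
# --------------------------------------------------------------------------------------------------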
And in ablation\n # settings like SAM training on static images, we don't need the memory encoder.\n run_mem_encoder=True,\n # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).\n prev_sam_mask_logits=None,\n ):\n \"\"\"Perform a single tracking step, updating object masks and memory features based on current frame inputs.\"\"\"\n current_out, sam_outputs, _, _ = self._track_step(\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n point_inputs,\n mask_inputs,\n output_dict,\n num_frames,\n track_in_reverse,\n prev_sam_mask_logits,\n )\n _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs\n\n current_out[\"pred_masks\"] = low_res_masks\n current_out[\"pred_masks_high_res\"] = high_res_masks\n current_out[\"obj_ptr\"] = obj_ptr\n if not self.training:\n # Only add this in inference (to avoid unused param in activation checkpointing;\n # it's mainly used in the demo to encode spatial memories w/ consolidated masks)\n current_out[\"object_score_logits\"] = object_score_logits\n\n # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)\n self._encode_memory_in_output(\n current_vision_feats,\n feat_sizes,\n point_inputs,\n run_mem_encoder,\n high_res_masks,\n object_score_logits,\n current_out,\n )\n\n return current_out\n\n def _use_multimask(self, is_init_cond_frame, point_inputs):\n \"\"\"Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs.\"\"\"\n num_pts = 0 if point_inputs is None else point_inputs[\"point_labels\"].size(1)\n return (\n self.multimask_output_in_sam\n and (is_init_cond_frame or self.multimask_output_for_tracking)\n and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)\n )\n\n @staticmethod\n def _apply_non_overlapping_constraints(pred_masks):\n \"\"\"Apply non-overlapping constraints to masks, keeping the highest scoring object per location.\"\"\"\n batch_size = pred_masks.size(0)\n if batch_size == 1:\n return pred_masks\n\n device = pred_masks.device\n # \"max_obj_inds\": object index of the object with the highest score at each location\n max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)\n # \"batch_obj_inds\": object index of each object slice (along dim 0) in `pred_masks`\n batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]\n keep = max_obj_inds == batch_obj_inds\n # suppress overlapping regions' scores below -10.0 so that the foreground regions\n # don't overlap (here sigmoid(-10.0)=4.5398e-05)\n pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))\n return pred_masks\n\n def set_binarize(self, binarize=False):\n \"\"\"Set binarize for VideoPredictor.\"\"\"\n self.binarize_mask_from_pts_for_mem_enc = binarize\n\n def set_imgsz(self, imgsz):\n \"\"\"Set image size to make model compatible with different image sizes.\"\"\"\n self.image_size = imgsz[0]\n self.sam_prompt_encoder.input_image_size = imgsz\n self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16", "chunk_type": "class", "name": "SAM2Model", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py", "start_line": 103, "end_line": 1037, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": "SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.\n\nThis class extends the functionality of SAM to 
handle video sequences, incorporating memory mechanisms\nfor temporal consistency and efficient tracking of objects across frames.\n\nAttributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Visual encoder for extracting image features.\n memory_attention (nn.Module): Module for attending to memory features.\n memory_encoder (nn.Module): Encoder for generating memory representations.\n num_maskmem (int): Number of accessible memory frames.\n image_size (int): Size of input images.\n backbone_stride (int): Stride of the backbone network output.\n sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.\n sam_image_embedding_size (int): Size of SAM image embeddings.\n sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.\n sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.\n obj_ptr_proj (nn.Module): Projection layer for object pointers.\n obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.\n hidden_dim (int): Hidden dimension of the model.\n mem_dim (int): Memory dimension for encoding features.\n use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.\n use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.\n max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.\n add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.\n proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional\n encoding in object pointers.\n use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.\n only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during\n evaluation.\n pred_obj_scores (bool): Whether to predict if there is an object in the frame.\n pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.\n fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.\n soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.\n use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.\n no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.\n max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.\n directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the\n first frame.\n multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial\n conditioning frames.\n multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.\n multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.\n multimask_output_for_tracking (bool): Whether to use multimask output for tracking.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].\n memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.\n non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in\n memory encoder during evaluation.\n sigmoid_scale_for_mem_enc 
(float): Scale factor for mask sigmoid probability.\n sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.\n binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames\n with clicks during evaluation.\n use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM\n prompt encoder and mask decoder on frames with mask input.\n\nMethods:\n forward_image: Process image batch through encoder to extract multi-level features.\n track_step: Perform a single tracking step, updating object masks and memory features.\n set_binarize: Set binarize for VideoPredictor.\n set_imgsz: Set image size to make model compatible with different image sizes.\n\nExamples:\n >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)\n >>> image_batch = torch.rand(1, 3, 512, 512)\n >>> features = model.forward_image(image_batch)\n >>> track_results = model.track_step(0, True, features, None, None, None, {})", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "torch", "torch.nn.functional", "torch.nn", "torch.nn.init.trunc_normal_", "ultralytics.nn.modules.MLP", "ultralytics.utils.LOGGER", "blocks.SAM2TwoWayTransformer", "decoders.MaskDecoder", "decoders.SAM2MaskDecoder", "encoders.ImageEncoderViT", "encoders.PromptEncoder", "utils.get_1d_sine_pe", "utils.select_closest_cond_frames", "torch.nn.Module" ], "chunk_id": "class_SAM2Model_8b658922" }, { "content": "import itertools", "chunk_type": "import", "name": "itertools", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_itertools_c0c04d63" }, { "content": "from typing import List, Optional, Tuple, Union", "chunk_type": "import", "name": "List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple, Union_50b5127e" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_dd9d1b04" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_76f46e0d" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": 
null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_9776ea9d" }, { "content": "from ultralytics.nn.modules import LayerNorm2d", "chunk_type": "import", "name": "LayerNorm2d", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LayerNorm2d_ec2dd36e" }, { "content": "from ultralytics.utils.instance import to_2tuple", "chunk_type": "import", "name": "to_2tuple", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 20, "end_line": 20, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_to_2tuple_dc97ab81" }, { "content": "class Conv2d_BN(torch.nn.Sequential):\n \"\"\"\n A sequential container that performs 2D convolution followed by batch normalization.\n\n This module combines a 2D convolution layer with batch normalization, providing a common building block\n for convolutional neural networks. The batch normalization weights and biases are initialized to specific\n values for optimal training performance.\n\n Attributes:\n c (torch.nn.Conv2d): 2D convolution layer.\n bn (torch.nn.BatchNorm2d): Batch normalization layer.\n\n Examples:\n >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output = conv_bn(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 64, 224, 224])\n \"\"\"\n\n def __init__(\n self,\n a: int,\n b: int,\n ks: int = 1,\n stride: int = 1,\n pad: int = 0,\n dilation: int = 1,\n groups: int = 1,\n bn_weight_init: float = 1,\n ):\n \"\"\"\n Initialize a sequential container with 2D convolution followed by batch normalization.\n\n Args:\n a (int): Number of input channels.\n b (int): Number of output channels.\n ks (int, optional): Kernel size for the convolution.\n stride (int, optional): Stride for the convolution.\n pad (int, optional): Padding for the convolution.\n dilation (int, optional): Dilation factor for the convolution.\n groups (int, optional): Number of groups for the convolution.\n bn_weight_init (float, optional): Initial value for batch normalization weight.\n \"\"\"\n super().__init__()\n self.add_module(\"c\", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))\n bn = torch.nn.BatchNorm2d(b)\n torch.nn.init.constant_(bn.weight, bn_weight_init)\n torch.nn.init.constant_(bn.bias, 0)\n self.add_module(\"bn\", bn)", "chunk_type": "class", "name": "Conv2d_BN", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 23, "end_line": 72, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "A sequential container that performs 2D convolution followed by batch normalization.\n\nThis module combines a 2D convolution layer with batch normalization, providing a common building block\nfor convolutional neural networks. 
The batch normalization weights and biases are initialized to specific\nvalues for optimal training performance.\n\nAttributes:\n c (torch.nn.Conv2d): 2D convolution layer.\n bn (torch.nn.BatchNorm2d): Batch normalization layer.\n\nExamples:\n >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output = conv_bn(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 64, 224, 224])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "torch.nn.Sequential" ], "chunk_id": "class_Conv2d_BN_4fa71554" }, { "content": "class PatchEmbed(nn.Module):\n \"\"\"\n Embed images into patches and project them into a specified embedding dimension.\n\n This module converts input images into patch embeddings using a sequence of convolutional layers,\n effectively downsampling the spatial dimensions while increasing the channel dimension.\n\n Attributes:\n patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.\n num_patches (int): Total number of patches.\n in_chans (int): Number of input channels.\n embed_dim (int): Dimension of the embedding.\n seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.\n\n Examples:\n >>> import torch\n >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 96, 56, 56])\n \"\"\"\n\n def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):\n \"\"\"\n Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.\n\n Args:\n in_chans (int): Number of input channels.\n embed_dim (int): Dimension of the embedding.\n resolution (int): Input image resolution.\n activation (nn.Module): Activation function to use between convolutions.\n \"\"\"\n super().__init__()\n img_size: Tuple[int, int] = to_2tuple(resolution)\n self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)\n self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]\n self.in_chans = in_chans\n self.embed_dim = embed_dim\n n = embed_dim\n self.seq = nn.Sequential(\n Conv2d_BN(in_chans, n // 2, 3, 2, 1),\n activation(),\n Conv2d_BN(n // 2, n, 3, 2, 1),\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input tensor through patch embedding sequence, converting images to patch embeddings.\"\"\"\n return self.seq(x)", "chunk_type": "class", "name": "PatchEmbed", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 75, "end_line": 123, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "Embed images into patches and project them into a specified embedding dimension.\n\nThis module converts input images into patch embeddings using a sequence of convolutional layers,\neffectively downsampling the spatial dimensions while increasing the channel dimension.\n\nAttributes:\n patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.\n num_patches (int): Total number of patches.\n in_chans (int): Number of input channels.\n embed_dim (int): Dimension of the embedding.\n seq (nn.Sequential): Sequence of convolutional and activation layers for 
patch embedding.\n\nExamples:\n >>> import torch\n >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 96, 56, 56])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_PatchEmbed_c1f155b0" }, { "content": "class MBConv(nn.Module):\n \"\"\"\n Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.\n\n This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,\n and projection phases, along with residual connections for improved gradient flow.\n\n Attributes:\n in_chans (int): Number of input channels.\n hidden_chans (int): Number of hidden channels after expansion.\n out_chans (int): Number of output channels.\n conv1 (Conv2d_BN): First convolutional layer for channel expansion.\n act1 (nn.Module): First activation function.\n conv2 (Conv2d_BN): Depthwise convolutional layer.\n act2 (nn.Module): Second activation function.\n conv3 (Conv2d_BN): Final convolutional layer for projection.\n act3 (nn.Module): Third activation function.\n drop_path (nn.Module): Drop path layer (Identity for inference).\n\n Examples:\n >>> in_chans, out_chans = 32, 64\n >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)\n >>> x = torch.randn(1, in_chans, 56, 56)\n >>> output = mbconv(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])\n \"\"\"\n\n def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):\n \"\"\"\n Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.\n\n Args:\n in_chans (int): Number of input channels.\n out_chans (int): Number of output channels.\n expand_ratio (float): Channel expansion ratio for the hidden layer.\n activation (nn.Module): Activation function to use.\n drop_path (float): Drop path rate for stochastic depth.\n \"\"\"\n super().__init__()\n self.in_chans = in_chans\n self.hidden_chans = int(in_chans * expand_ratio)\n self.out_chans = out_chans\n\n self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)\n self.act1 = activation()\n\n self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)\n self.act2 = activation()\n\n self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)\n self.act3 = activation()\n\n # NOTE: `DropPath` is needed only for training.\n # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n self.drop_path = nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Implement the forward pass of MBConv, applying convolutions and skip connection.\"\"\"\n shortcut = x\n x = self.conv1(x)\n x = self.act1(x)\n x = self.conv2(x)\n x = self.act2(x)\n x = self.conv3(x)\n x = self.drop_path(x)\n x += shortcut\n return self.act3(x)", "chunk_type": "class", "name": "MBConv", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 126, "end_line": 193, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.\n\nThis module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,\nand projection phases, along with residual connections for improved gradient flow.\n\nAttributes:\n in_chans (int): Number of input channels.\n hidden_chans (int): Number of hidden channels after expansion.\n out_chans (int): Number of output channels.\n conv1 (Conv2d_BN): First convolutional layer for channel expansion.\n act1 (nn.Module): First activation function.\n conv2 (Conv2d_BN): Depthwise convolutional layer.\n act2 (nn.Module): Second activation function.\n conv3 (Conv2d_BN): Final convolutional layer for projection.\n act3 (nn.Module): Third activation function.\n drop_path (nn.Module): Drop path layer (Identity for inference).\n\nExamples:\n >>> in_chans, out_chans = 32, 64\n >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)\n >>> x = torch.randn(1, in_chans, 56, 56)\n >>> output = mbconv(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_MBConv_189eda56" }, { "content": "class PatchMerging(nn.Module):\n \"\"\"\n Merge neighboring patches in the feature map and project to a new dimension.\n\n This class implements a patch merging operation that combines spatial information and adjusts the feature\n dimension using a series of convolutional layers with batch normalization. 
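A note on the MBConv block above: it hard-codes `self.drop_path = nn.Identity()` because this export targets inference, and the commented-out line only hints at the training-time behaviour. The sketch below is an editor's illustration of a minimal stochastic-depth layer in the spirit of that comment; it is not the ultralytics implementation (training code typically uses timm's `DropPath`), and re-enabling it via `mbconv.drop_path = DropPath(0.1)` is purely hypothetical.

import torch
import torch.nn as nn


class DropPath(nn.Module):
    """Minimal stochastic depth: randomly zero whole residual branches per sample during training."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x  # no-op at inference, matching the nn.Identity() used above
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all remaining dimensions
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob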
It effectively reduces spatial\n resolution while potentially increasing channel dimensions.\n\n Attributes:\n input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.\n dim (int): The input dimension of the feature map.\n out_dim (int): The output dimension after merging and projection.\n act (nn.Module): The activation function used between convolutions.\n conv1 (Conv2d_BN): The first convolutional layer for dimension projection.\n conv2 (Conv2d_BN): The second convolutional layer for spatial merging.\n conv3 (Conv2d_BN): The third convolutional layer for final projection.\n\n Examples:\n >>> input_resolution = (56, 56)\n >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)\n >>> x = torch.randn(4, 64, 56, 56)\n >>> output = patch_merging(x)\n >>> print(output.shape)\n torch.Size([4, 3136, 128])\n \"\"\"\n\n def __init__(self, input_resolution: Tuple[int, int], dim: int, out_dim: int, activation):\n \"\"\"\n Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.\n\n Args:\n input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.\n dim (int): The input dimension of the feature map.\n out_dim (int): The output dimension after merging and projection.\n activation (nn.Module): The activation function used between convolutions.\n \"\"\"\n super().__init__()\n\n self.input_resolution = input_resolution\n self.dim = dim\n self.out_dim = out_dim\n self.act = activation()\n self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)\n stride_c = 1 if out_dim in {320, 448, 576} else 2\n self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)\n self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply patch merging and dimension projection to the input feature map.\"\"\"\n if x.ndim == 3:\n H, W = self.input_resolution\n B = len(x)\n # (B, C, H, W)\n x = x.view(B, H, W, -1).permute(0, 3, 1, 2)\n\n x = self.conv1(x)\n x = self.act(x)\n\n x = self.conv2(x)\n x = self.act(x)\n x = self.conv3(x)\n return x.flatten(2).transpose(1, 2)", "chunk_type": "class", "name": "PatchMerging", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 196, "end_line": 257, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Merge neighboring patches in the feature map and project to a new dimension.\n\nThis class implements a patch merging operation that combines spatial information and adjusts the feature\ndimension using a series of convolutional layers with batch normalization. 
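To make the `stride_c` rule in PatchMerging above concrete: for most stage widths the 3x3 depthwise convolution uses stride 2 and halves the spatial grid, while for out_dim values in {320, 448, 576} (presumably the final-stage widths of the TinyViT variants, where no further downsampling is wanted) it keeps the grid and only reprojects channels. A quick shape check (editor's sketch; it assumes the Conv2d_BN and PatchMerging classes above are in scope, and the concrete dims are arbitrary examples):

import torch
import torch.nn as nn

# Stride-2 branch: out_dim not in {320, 448, 576}, so the 56x56 grid becomes 28x28 tokens
pm_down = PatchMerging(input_resolution=(56, 56), dim=64, out_dim=128, activation=nn.GELU)
print(pm_down(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 784, 128])

# Stride-1 branch: out_dim=320 is in the set, so the 14x14 grid is preserved
pm_keep = PatchMerging(input_resolution=(14, 14), dim=160, out_dim=320, activation=nn.GELU)
print(pm_keep(torch.randn(2, 160, 14, 14)).shape)  # torch.Size([2, 196, 320])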
It effectively reduces spatial\nresolution while potentially increasing channel dimensions.\n\nAttributes:\n input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.\n dim (int): The input dimension of the feature map.\n out_dim (int): The output dimension after merging and projection.\n act (nn.Module): The activation function used between convolutions.\n conv1 (Conv2d_BN): The first convolutional layer for dimension projection.\n conv2 (Conv2d_BN): The second convolutional layer for spatial merging.\n conv3 (Conv2d_BN): The third convolutional layer for final projection.\n\nExamples:\n >>> input_resolution = (56, 56)\n >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)\n >>> x = torch.randn(4, 64, 56, 56)\n >>> output = patch_merging(x)\n >>> print(output.shape)\n torch.Size([4, 3136, 128])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_PatchMerging_a0111844" }, { "content": "class ConvLayer(nn.Module):\n \"\"\"\n Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).\n\n This layer optionally applies downsample operations to the output and supports gradient checkpointing\n for memory efficiency during training.\n\n Attributes:\n dim (int): Dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Resolution of the input image.\n depth (int): Number of MBConv layers in the block.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of MBConv layers.\n downsample (Optional[nn.Module]): Function for downsampling the output.\n\n Examples:\n >>> input_tensor = torch.randn(1, 64, 56, 56)\n >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)\n >>> output = conv_layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 3136, 128])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n input_resolution: Tuple[int, int],\n depth: int,\n activation,\n drop_path: Union[float, List[float]] = 0.0,\n downsample: Optional[nn.Module] = None,\n use_checkpoint: bool = False,\n out_dim: Optional[int] = None,\n conv_expand_ratio: float = 4.0,\n ):\n \"\"\"\n Initialize the ConvLayer with the given dimensions and settings.\n\n This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and\n optionally applies downsampling to the output.\n\n Args:\n dim (int): The dimensionality of the input and output.\n input_resolution (Tuple[int, int]): The resolution of the input image.\n depth (int): The number of MBConv layers in the block.\n activation (nn.Module): Activation function applied after each convolution.\n drop_path (float | List[float], optional): Drop path rate. Single float or a list of floats for each MBConv.\n downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.\n use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.\n out_dim (Optional[int], optional): The dimensionality of the output. 
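To make the PatchMerging shape bookkeeping concrete, the sketch below walks a (B, N, C) token tensor through the same reshape → 1x1 → stride-2 depthwise 3x3 → 1x1 → flatten pipeline, using plain nn.Conv2d in place of Conv2d_BN. Stride 2 is shown as the common case; as the record above notes, the repository keeps stride 1 for out_dim values of 320, 448 and 576. All sizes here are illustrative.

```python
import torch
import torch.nn as nn

B, H, W, dim, out_dim = 2, 56, 56, 64, 128
x = torch.randn(B, H * W, dim)                       # (B, N, C) token layout

x = x.view(B, H, W, dim).permute(0, 3, 1, 2)         # -> (B, C, H, W) so convolutions apply
x = nn.Conv2d(dim, out_dim, 1)(x)                    # 1x1 projection to the new width
x = nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1, groups=out_dim)(x)  # depthwise merge, halves H and W
x = nn.Conv2d(out_dim, out_dim, 1)(x)                # final 1x1 projection
tokens = x.flatten(2).transpose(1, 2)                # back to (B, N/4, out_dim)
print(tokens.shape)                                  # torch.Size([2, 784, 128])
```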
None means it will be the same as `dim`.\n conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.\n \"\"\"\n super().__init__()\n self.dim = dim\n self.input_resolution = input_resolution\n self.depth = depth\n self.use_checkpoint = use_checkpoint\n\n # Build blocks\n self.blocks = nn.ModuleList(\n [\n MBConv(\n dim,\n dim,\n conv_expand_ratio,\n activation,\n drop_path[i] if isinstance(drop_path, list) else drop_path,\n )\n for i in range(depth)\n ]\n )\n\n # Patch merging layer\n self.downsample = (\n None\n if downsample is None\n else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through convolutional layers, applying MBConv blocks and optional downsampling.\"\"\"\n for blk in self.blocks:\n x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x) # warn: checkpoint is slow import\n return x if self.downsample is None else self.downsample(x)", "chunk_type": "class", "name": "ConvLayer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 260, "end_line": 343, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).\n\nThis layer optionally applies downsample operations to the output and supports gradient checkpointing\nfor memory efficiency during training.\n\nAttributes:\n dim (int): Dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Resolution of the input image.\n depth (int): Number of MBConv layers in the block.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of MBConv layers.\n downsample (Optional[nn.Module]): Function for downsampling the output.\n\nExamples:\n >>> input_tensor = torch.randn(1, 64, 56, 56)\n >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)\n >>> output = conv_layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 3136, 128])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_ConvLayer_b164e6de" }, { "content": "class MLP(nn.Module):\n \"\"\"\n Multi-layer Perceptron (MLP) module for transformer architectures.\n\n This module applies layer normalization, two fully-connected layers with an activation function in between,\n and dropout. 
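ConvLayer (and BasicLayer later in this file) optionally wraps each block in gradient checkpointing to trade recomputation for activation memory. A small self-contained demonstration of that pattern follows; note that the callable lives at torch.utils.checkpoint.checkpoint (the torch.utils.checkpoint module itself is not callable), and the block stack here is a made-up stand-in for the MBConv list.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-in for ConvLayer.blocks: any sequence of nn.Modules works for the demonstration.
blocks = nn.ModuleList(nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.GELU()) for _ in range(3))

x = torch.randn(2, 8, 16, 16, requires_grad=True)
use_checkpoint = True
for blk in blocks:
    # With checkpointing, activations inside blk are recomputed during backward instead of stored.
    x = checkpoint(blk, x, use_reentrant=False) if use_checkpoint else blk(x)
x.sum().backward()
print(blocks[0][0].weight.grad is not None)  # True: gradients still flow through checkpointed blocks
```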
It is commonly used in transformer-based architectures for processing token embeddings.\n\n Attributes:\n norm (nn.LayerNorm): Layer normalization applied to the input.\n fc1 (nn.Linear): First fully-connected layer.\n fc2 (nn.Linear): Second fully-connected layer.\n act (nn.Module): Activation function applied after the first fully-connected layer.\n drop (nn.Dropout): Dropout layer applied after the activation function.\n\n Examples:\n >>> import torch\n >>> from torch import nn\n >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)\n >>> x = torch.randn(32, 100, 256)\n >>> output = mlp(x)\n >>> print(output.shape)\n torch.Size([32, 100, 256])\n \"\"\"\n\n def __init__(\n self,\n in_features: int,\n hidden_features: Optional[int] = None,\n out_features: Optional[int] = None,\n activation=nn.GELU,\n drop: float = 0.0,\n ):\n \"\"\"\n Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.\n\n Args:\n in_features (int): Number of input features.\n hidden_features (Optional[int], optional): Number of hidden features.\n out_features (Optional[int], optional): Number of output features.\n activation (nn.Module): Activation function applied after the first fully-connected layer.\n drop (float, optional): Dropout probability.\n \"\"\"\n super().__init__()\n out_features = out_features or in_features\n hidden_features = hidden_features or in_features\n self.norm = nn.LayerNorm(in_features)\n self.fc1 = nn.Linear(in_features, hidden_features)\n self.fc2 = nn.Linear(hidden_features, out_features)\n self.act = activation()\n self.drop = nn.Dropout(drop)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor.\"\"\"\n x = self.norm(x)\n x = self.fc1(x)\n x = self.act(x)\n x = self.drop(x)\n x = self.fc2(x)\n return self.drop(x)", "chunk_type": "class", "name": "MLP", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 346, "end_line": 404, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Multi-layer Perceptron (MLP) module for transformer architectures.\n\nThis module applies layer normalization, two fully-connected layers with an activation function in between,\nand dropout. 
It is commonly used in transformer-based architectures for processing token embeddings.\n\nAttributes:\n norm (nn.LayerNorm): Layer normalization applied to the input.\n fc1 (nn.Linear): First fully-connected layer.\n fc2 (nn.Linear): Second fully-connected layer.\n act (nn.Module): Activation function applied after the first fully-connected layer.\n drop (nn.Dropout): Dropout layer applied after the activation function.\n\nExamples:\n >>> import torch\n >>> from torch import nn\n >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)\n >>> x = torch.randn(32, 100, 256)\n >>> output = mlp(x)\n >>> print(output.shape)\n torch.Size([32, 100, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_MLP_abe157c9" }, { "content": "class Attention(torch.nn.Module):\n \"\"\"\n Multi-head attention module with spatial awareness and trainable attention biases.\n\n This module implements a multi-head attention mechanism with support for spatial awareness, applying\n attention biases based on spatial resolution. It includes trainable attention biases for each unique\n offset between spatial positions in the resolution grid.\n\n Attributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention scores.\n key_dim (int): Dimensionality of the keys and queries.\n nh_kd (int): Product of num_heads and key_dim.\n d (int): Dimensionality of the value vectors.\n dh (int): Product of d and num_heads.\n attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.\n norm (nn.LayerNorm): Layer normalization applied to input.\n qkv (nn.Linear): Linear layer for computing query, key, and value projections.\n proj (nn.Linear): Linear layer for final projection.\n attention_biases (nn.Parameter): Learnable attention biases.\n attention_bias_idxs (torch.Tensor): Indices for attention biases.\n ab (torch.Tensor): Cached attention biases for inference, deleted during training.\n\n Examples:\n >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))\n >>> x = torch.randn(1, 196, 256)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 196, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n key_dim: int,\n num_heads: int = 8,\n attn_ratio: float = 4,\n resolution: Tuple[int, int] = (14, 14),\n ):\n \"\"\"\n Initialize the Attention module for multi-head attention with spatial awareness.\n\n This module implements a multi-head attention mechanism with support for spatial awareness, applying\n attention biases based on spatial resolution. 
It includes trainable attention biases for each unique\n offset between spatial positions in the resolution grid.\n\n Args:\n dim (int): The dimensionality of the input and output.\n key_dim (int): The dimensionality of the keys and queries.\n num_heads (int, optional): Number of attention heads.\n attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.\n resolution (Tuple[int, int], optional): Spatial resolution of the input feature map.\n \"\"\"\n super().__init__()\n\n assert isinstance(resolution, tuple) and len(resolution) == 2, \"'resolution' argument not tuple of length 2\"\n self.num_heads = num_heads\n self.scale = key_dim**-0.5\n self.key_dim = key_dim\n self.nh_kd = nh_kd = key_dim * num_heads\n self.d = int(attn_ratio * key_dim)\n self.dh = int(attn_ratio * key_dim) * num_heads\n self.attn_ratio = attn_ratio\n h = self.dh + nh_kd * 2\n\n self.norm = nn.LayerNorm(dim)\n self.qkv = nn.Linear(dim, h)\n self.proj = nn.Linear(self.dh, dim)\n\n points = list(itertools.product(range(resolution[0]), range(resolution[1])))\n N = len(points)\n attention_offsets = {}\n idxs = []\n for p1 in points:\n for p2 in points:\n offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))\n if offset not in attention_offsets:\n attention_offsets[offset] = len(attention_offsets)\n idxs.append(attention_offsets[offset])\n self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))\n self.register_buffer(\"attention_bias_idxs\", torch.LongTensor(idxs).view(N, N), persistent=False)\n\n @torch.no_grad()\n def train(self, mode: bool = True):\n \"\"\"Set the module in training mode and handle the 'ab' attribute for cached attention biases.\"\"\"\n super().train(mode)\n if mode and hasattr(self, \"ab\"):\n del self.ab\n else:\n self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply multi-head attention with spatial awareness and trainable attention biases.\"\"\"\n B, N, _ = x.shape # B, N, C\n\n # Normalization\n x = self.norm(x)\n\n qkv = self.qkv(x)\n # (B, N, num_heads, d)\n q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)\n # (B, num_heads, N, d)\n q = q.permute(0, 2, 1, 3)\n k = k.permute(0, 2, 1, 3)\n v = v.permute(0, 2, 1, 3)\n self.ab = self.ab.to(self.attention_biases.device)\n\n attn = (q @ k.transpose(-2, -1)) * self.scale + (\n self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab\n )\n attn = attn.softmax(dim=-1)\n x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)\n return self.proj(x)", "chunk_type": "class", "name": "Attention", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 407, "end_line": 519, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Multi-head attention module with spatial awareness and trainable attention biases.\n\nThis module implements a multi-head attention mechanism with support for spatial awareness, applying\nattention biases based on spatial resolution. 
It includes trainable attention biases for each unique\noffset between spatial positions in the resolution grid.\n\nAttributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention scores.\n key_dim (int): Dimensionality of the keys and queries.\n nh_kd (int): Product of num_heads and key_dim.\n d (int): Dimensionality of the value vectors.\n dh (int): Product of d and num_heads.\n attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.\n norm (nn.LayerNorm): Layer normalization applied to input.\n qkv (nn.Linear): Linear layer for computing query, key, and value projections.\n proj (nn.Linear): Linear layer for final projection.\n attention_biases (nn.Parameter): Learnable attention biases.\n attention_bias_idxs (torch.Tensor): Indices for attention biases.\n ab (torch.Tensor): Cached attention biases for inference, deleted during training.\n\nExamples:\n >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))\n >>> x = torch.randn(1, 196, 256)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 196, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "torch.nn.Module" ], "chunk_id": "class_Attention_04fc8140" }, { "content": "class TinyViTBlock(nn.Module):\n \"\"\"\n TinyViT Block that applies self-attention and a local convolution to the input.\n\n This block is a key component of the TinyViT architecture, combining self-attention mechanisms with\n local convolutions to process input features efficiently. It supports windowed attention for\n computational efficiency and includes residual connections.\n\n Attributes:\n dim (int): The dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n num_heads (int): Number of attention heads.\n window_size (int): Size of the attention window.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n drop_path (nn.Module): Stochastic depth layer, identity function during inference.\n attn (Attention): Self-attention module.\n mlp (MLP): Multi-layer perceptron module.\n local_conv (Conv2d_BN): Depth-wise local convolution layer.\n\n Examples:\n >>> input_tensor = torch.randn(1, 196, 192)\n >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)\n >>> output = block(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 196, 192])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n input_resolution: Tuple[int, int],\n num_heads: int,\n window_size: int = 7,\n mlp_ratio: float = 4.0,\n drop: float = 0.0,\n drop_path: float = 0.0,\n local_conv_size: int = 3,\n activation=nn.GELU,\n ):\n \"\"\"\n Initialize a TinyViT block with self-attention and local convolution.\n\n This block is a key component of the TinyViT architecture, combining self-attention mechanisms with\n local convolutions to process input features efficiently.\n\n Args:\n dim (int): Dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).\n num_heads (int): Number of attention heads.\n window_size (int, optional): Size of the attention window. 
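The Attention module above stores one trainable bias per head and per unique absolute offset between grid positions, then expands them to an (N, N) map through attention_bias_idxs. The sketch below reproduces that bookkeeping on a tiny 3x3 grid so the sizes are easy to check; the grid size is an illustrative choice (the real windows are e.g. 7x7 or 14x14).

```python
import itertools
import torch

resolution = (3, 3)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
attention_offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))   # offsets are symmetric, so abs() dedupes them
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

N = len(points)
bias_idxs = torch.LongTensor(idxs).view(N, N)          # lookup from (query, key) pair to shared bias slot
print(N, len(attention_offsets))                       # 9 positions, 9 unique offsets (vs. 81 pairs)
print(bias_idxs.shape)                                 # torch.Size([9, 9])
```

Only len(attention_offsets) parameters per head are learned; the full (N, N) bias map is gathered from them at forward time and cached as `ab` when the module is in eval mode, as the train() override above shows.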
Must be greater than 0.\n mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.\n drop (float, optional): Dropout rate.\n drop_path (float, optional): Stochastic depth rate.\n local_conv_size (int, optional): Kernel size of the local convolution.\n activation (nn.Module): Activation function for MLP.\n \"\"\"\n super().__init__()\n self.dim = dim\n self.input_resolution = input_resolution\n self.num_heads = num_heads\n assert window_size > 0, \"window_size must be greater than 0\"\n self.window_size = window_size\n self.mlp_ratio = mlp_ratio\n\n # NOTE: `DropPath` is needed only for training.\n # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n self.drop_path = nn.Identity()\n\n assert dim % num_heads == 0, \"dim must be divisible by num_heads\"\n head_dim = dim // num_heads\n\n window_resolution = (window_size, window_size)\n self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)\n\n mlp_hidden_dim = int(dim * mlp_ratio)\n mlp_activation = activation\n self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)\n\n pad = local_conv_size // 2\n self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply self-attention, local convolution, and MLP operations to the input tensor.\"\"\"\n h, w = self.input_resolution\n b, hw, c = x.shape # batch, height*width, channels\n assert hw == h * w, \"input feature has wrong size\"\n res_x = x\n if h == self.window_size and w == self.window_size:\n x = self.attn(x)\n else:\n x = x.view(b, h, w, c)\n pad_b = (self.window_size - h % self.window_size) % self.window_size\n pad_r = (self.window_size - w % self.window_size) % self.window_size\n padding = pad_b > 0 or pad_r > 0\n if padding:\n x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))\n\n pH, pW = h + pad_b, w + pad_r\n nH = pH // self.window_size\n nW = pW // self.window_size\n\n # Window partition\n x = (\n x.view(b, nH, self.window_size, nW, self.window_size, c)\n .transpose(2, 3)\n .reshape(b * nH * nW, self.window_size * self.window_size, c)\n )\n x = self.attn(x)\n\n # Window reverse\n x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)\n if padding:\n x = x[:, :h, :w].contiguous()\n\n x = x.view(b, hw, c)\n\n x = res_x + self.drop_path(x)\n x = x.transpose(1, 2).reshape(b, c, h, w)\n x = self.local_conv(x)\n x = x.view(b, c, hw).transpose(1, 2)\n\n return x + self.drop_path(self.mlp(x))\n\n def extra_repr(self) -> str:\n \"\"\"\n Return a string representation of the TinyViTBlock's parameters.\n\n This method provides a formatted string containing key information about the TinyViTBlock, including its\n dimension, input resolution, number of attention heads, window size, and MLP ratio.\n\n Returns:\n (str): A formatted string containing the block's parameters.\n\n Examples:\n >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)\n >>> print(block.extra_repr())\n dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0\n \"\"\"\n return (\n f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \"\n f\"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}\"\n )", "chunk_type": "class", "name": "TinyViTBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 522, "end_line": 
663, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "TinyViT Block that applies self-attention and a local convolution to the input.\n\nThis block is a key component of the TinyViT architecture, combining self-attention mechanisms with\nlocal convolutions to process input features efficiently. It supports windowed attention for\ncomputational efficiency and includes residual connections.\n\nAttributes:\n dim (int): The dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n num_heads (int): Number of attention heads.\n window_size (int): Size of the attention window.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n drop_path (nn.Module): Stochastic depth layer, identity function during inference.\n attn (Attention): Self-attention module.\n mlp (MLP): Multi-layer perceptron module.\n local_conv (Conv2d_BN): Depth-wise local convolution layer.\n\nExamples:\n >>> input_tensor = torch.randn(1, 196, 192)\n >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)\n >>> output = block(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 196, 192])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_TinyViTBlock_f2ae92c6" }, { "content": "class BasicLayer(nn.Module):\n \"\"\"\n A basic TinyViT layer for one stage in a TinyViT architecture.\n\n This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks\n and an optional downsampling operation. It processes features at a specific resolution and\n dimensionality within the overall architecture.\n\n Attributes:\n dim (int): The dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n depth (int): Number of TinyViT blocks in this layer.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.\n downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.\n\n Examples:\n >>> input_tensor = torch.randn(1, 3136, 192)\n >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)\n >>> output = layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 784, 384])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n input_resolution: Tuple[int, int],\n depth: int,\n num_heads: int,\n window_size: int,\n mlp_ratio: float = 4.0,\n drop: float = 0.0,\n drop_path: Union[float, List[float]] = 0.0,\n downsample: Optional[nn.Module] = None,\n use_checkpoint: bool = False,\n local_conv_size: int = 3,\n activation=nn.GELU,\n out_dim: Optional[int] = None,\n ):\n \"\"\"\n Initialize a BasicLayer in the TinyViT architecture.\n\n This layer consists of multiple TinyViT blocks and an optional downsampling operation. 
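TinyViTBlock's forward above pads the feature map to a multiple of the window size, reshapes it into independent windows for attention, then reverses the reshape. The round trip below checks that window partition and reverse are exact inverses; the sizes are illustrative and the map is assumed to be already padded.

```python
import torch

b, c, ws = 1, 8, 7
pH = pW = 14                                   # padded height/width: multiples of the window size
nH, nW = pH // ws, pW // ws
x = torch.randn(b, pH, pW, c)

windows = (
    x.view(b, nH, ws, nW, ws, c)
    .transpose(2, 3)
    .reshape(b * nH * nW, ws * ws, c)          # every window becomes its own batch entry for attention
)
print(windows.shape)                           # torch.Size([4, 49, 8])

restored = windows.view(b, nH, nW, ws, ws, c).transpose(2, 3).reshape(b, pH, pW, c)
print(torch.equal(restored, x))                # True: the reverse reshape recovers the original map
```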
It is designed to\n process feature maps at a specific resolution and dimensionality within the TinyViT model.\n\n Args:\n dim (int): Dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).\n depth (int): Number of TinyViT blocks in this layer.\n num_heads (int): Number of attention heads in each TinyViT block.\n window_size (int): Size of the local window for attention computation.\n mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.\n drop (float, optional): Dropout rate.\n drop_path (float | List[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.\n downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.\n use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.\n local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.\n activation (nn.Module): Activation function used in the MLP.\n out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as `dim`.\n \"\"\"\n super().__init__()\n self.dim = dim\n self.input_resolution = input_resolution\n self.depth = depth\n self.use_checkpoint = use_checkpoint\n\n # Build blocks\n self.blocks = nn.ModuleList(\n [\n TinyViTBlock(\n dim=dim,\n input_resolution=input_resolution,\n num_heads=num_heads,\n window_size=window_size,\n mlp_ratio=mlp_ratio,\n drop=drop,\n drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n local_conv_size=local_conv_size,\n activation=activation,\n )\n for i in range(depth)\n ]\n )\n\n # Patch merging layer\n self.downsample = (\n None\n if downsample is None\n else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through TinyViT blocks and optional downsampling.\"\"\"\n for blk in self.blocks:\n x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x) # warn: checkpoint is slow import\n return x if self.downsample is None else self.downsample(x)\n\n def extra_repr(self) -> str:\n \"\"\"Return a string with the layer's parameters for printing.\"\"\"\n return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"", "chunk_type": "class", "name": "BasicLayer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 666, "end_line": 766, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": "A basic TinyViT layer for one stage in a TinyViT architecture.\n\nThis class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks\nand an optional downsampling operation. 
It processes features at a specific resolution and\ndimensionality within the overall architecture.\n\nAttributes:\n dim (int): The dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n depth (int): Number of TinyViT blocks in this layer.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.\n downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.\n\nExamples:\n >>> input_tensor = torch.randn(1, 3136, 192)\n >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)\n >>> output = layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 784, 384])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_BasicLayer_d9b97223" }, { "content": "class TinyViT(nn.Module):\n \"\"\"\n TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.\n\n This class implements the TinyViT model, which combines elements of vision transformers and convolutional\n neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing\n with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.\n\n Attributes:\n img_size (int): Input image size.\n num_classes (int): Number of classification classes.\n depths (Tuple[int, int, int, int]): Number of blocks in each stage.\n num_layers (int): Total number of layers in the network.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n patch_embed (PatchEmbed): Module for patch embedding.\n patches_resolution (Tuple[int, int]): Resolution of embedded patches.\n layers (nn.ModuleList): List of network layers.\n norm_head (nn.LayerNorm): Layer normalization for the classifier head.\n head (nn.Linear): Linear layer for final classification.\n neck (nn.Sequential): Neck module for feature refinement.\n\n Examples:\n >>> model = TinyViT(img_size=224, num_classes=1000)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> features = model.forward_features(x)\n >>> print(features.shape)\n torch.Size([1, 256, 56, 56])\n \"\"\"\n\n def __init__(\n self,\n img_size: int = 224,\n in_chans: int = 3,\n num_classes: int = 1000,\n embed_dims: Tuple[int, int, int, int] = (96, 192, 384, 768),\n depths: Tuple[int, int, int, int] = (2, 2, 6, 2),\n num_heads: Tuple[int, int, int, int] = (3, 6, 12, 24),\n window_sizes: Tuple[int, int, int, int] = (7, 7, 14, 7),\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n drop_path_rate: float = 0.1,\n use_checkpoint: bool = False,\n mbconv_expand_ratio: float = 4.0,\n local_conv_size: int = 3,\n layer_lr_decay: float = 1.0,\n ):\n \"\"\"\n Initialize the TinyViT model.\n\n This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of\n attention and convolution blocks, and a classification head.\n\n Args:\n img_size (int, optional): Size of the input image.\n in_chans (int, optional): Number of input channels.\n num_classes (int, optional): Number of classes for classification.\n embed_dims (Tuple[int, int, int, int], optional): 
Embedding dimensions for each stage.\n depths (Tuple[int, int, int, int], optional): Number of blocks in each stage.\n num_heads (Tuple[int, int, int, int], optional): Number of attention heads in each stage.\n window_sizes (Tuple[int, int, int, int], optional): Window sizes for each stage.\n mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.\n drop_rate (float, optional): Dropout rate.\n drop_path_rate (float, optional): Stochastic depth rate.\n use_checkpoint (bool, optional): Whether to use checkpointing to save memory.\n mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.\n local_conv_size (int, optional): Kernel size for local convolutions.\n layer_lr_decay (float, optional): Layer-wise learning rate decay factor.\n \"\"\"\n super().__init__()\n self.img_size = img_size\n self.num_classes = num_classes\n self.depths = depths\n self.num_layers = len(depths)\n self.mlp_ratio = mlp_ratio\n\n activation = nn.GELU\n\n self.patch_embed = PatchEmbed(\n in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation\n )\n\n patches_resolution = self.patch_embed.patches_resolution\n self.patches_resolution = patches_resolution\n\n # Stochastic depth\n dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule\n\n # Build layers\n self.layers = nn.ModuleList()\n for i_layer in range(self.num_layers):\n kwargs = dict(\n dim=embed_dims[i_layer],\n input_resolution=(\n patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),\n patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),\n ),\n # input_resolution=(patches_resolution[0] // (2 ** i_layer),\n # patches_resolution[1] // (2 ** i_layer)),\n depth=depths[i_layer],\n drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n use_checkpoint=use_checkpoint,\n out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],\n activation=activation,\n )\n if i_layer == 0:\n layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)\n else:\n layer = BasicLayer(\n num_heads=num_heads[i_layer],\n window_size=window_sizes[i_layer],\n mlp_ratio=self.mlp_ratio,\n drop=drop_rate,\n local_conv_size=local_conv_size,\n **kwargs,\n )\n self.layers.append(layer)\n\n # Classifier head\n self.norm_head = nn.LayerNorm(embed_dims[-1])\n self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()\n\n # Init weights\n self.apply(self._init_weights)\n self.set_layer_lr_decay(layer_lr_decay)\n self.neck = nn.Sequential(\n nn.Conv2d(\n embed_dims[-1],\n 256,\n kernel_size=1,\n bias=False,\n ),\n LayerNorm2d(256),\n nn.Conv2d(\n 256,\n 256,\n kernel_size=3,\n padding=1,\n bias=False,\n ),\n LayerNorm2d(256),\n )\n\n def set_layer_lr_decay(self, layer_lr_decay: float):\n \"\"\"Set layer-wise learning rate decay for the TinyViT model based on depth.\"\"\"\n decay_rate = layer_lr_decay\n\n # Layers -> blocks (depth)\n depth = sum(self.depths)\n lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]\n\n def _set_lr_scale(m, scale):\n \"\"\"Set the learning rate scale for each layer in the model based on the layer's depth.\"\"\"\n for p in m.parameters():\n p.lr_scale = scale\n\n self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))\n i = 0\n for layer in self.layers:\n for block in layer.blocks:\n block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))\n i += 1\n if layer.downsample is not 
None:\n layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))\n assert i == depth\n for m in {self.norm_head, self.head}:\n m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))\n\n for k, p in self.named_parameters():\n p.param_name = k\n\n def _check_lr_scale(m):\n \"\"\"Check if the learning rate scale attribute is present in module's parameters.\"\"\"\n for p in m.parameters():\n assert hasattr(p, \"lr_scale\"), p.param_name\n\n self.apply(_check_lr_scale)\n\n @staticmethod\n def _init_weights(m):\n \"\"\"Initialize weights for linear and normalization layers in the TinyViT model.\"\"\"\n if isinstance(m, nn.Linear):\n # NOTE: This initialization is needed only for training.\n # trunc_normal_(m.weight, std=.02)\n if m.bias is not None:\n nn.init.constant_(m.bias, 0)\n elif isinstance(m, nn.LayerNorm):\n nn.init.constant_(m.bias, 0)\n nn.init.constant_(m.weight, 1.0)\n\n @torch.jit.ignore\n def no_weight_decay_keywords(self):\n \"\"\"Return a set of keywords for parameters that should not use weight decay.\"\"\"\n return {\"attention_biases\"}\n\n def forward_features(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through feature extraction layers, returning spatial features.\"\"\"\n x = self.patch_embed(x) # x input is (N, C, H, W)\n\n x = self.layers[0](x)\n start_i = 1\n\n for i in range(start_i, len(self.layers)):\n layer = self.layers[i]\n x = layer(x)\n batch, _, channel = x.shape\n x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)\n x = x.permute(0, 3, 1, 2)\n return self.neck(x)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Perform the forward pass through the TinyViT model, extracting features from the input image.\"\"\"\n return self.forward_features(x)\n\n def set_imgsz(self, imgsz: List[int] = [1024, 1024]):\n \"\"\"Set image size to make model compatible with different image sizes.\"\"\"\n imgsz = [s // 4 for s in imgsz]\n self.patches_resolution = imgsz\n for i, layer in enumerate(self.layers):\n input_resolution = (\n imgsz[0] // (2 ** (i - 1 if i == 3 else i)),\n imgsz[1] // (2 ** (i - 1 if i == 3 else i)),\n )\n layer.input_resolution = input_resolution\n if layer.downsample is not None:\n layer.downsample.input_resolution = input_resolution\n if isinstance(layer, BasicLayer):\n for b in layer.blocks:\n b.input_resolution = input_resolution", "chunk_type": "class", "name": "TinyViT", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py", "start_line": 769, "end_line": 997, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": "TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.\n\nThis class implements the TinyViT model, which combines elements of vision transformers and convolutional\nneural networks for improved efficiency and performance on vision tasks. 
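One scheduling detail in TinyViT.__init__ above is easy to miss: the stochastic-depth rate is spread linearly over every block in the network, and each stage then receives its own slice of that schedule. The short sketch below prints the per-stage slices for the default depths; the values only illustrate the rule, they are not tuned rates.

```python
import torch

depths = (2, 2, 6, 2)
drop_path_rate = 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # one rate per block, 0 -> max

for i_layer in range(len(depths)):
    stage_rates = dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])]
    print(f"stage {i_layer}: {[round(r, 3) for r in stage_rates]}")
# Early blocks keep (almost) every path; the deepest blocks approach drop_path_rate.
```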
It features hierarchical processing\nwith patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.\n\nAttributes:\n img_size (int): Input image size.\n num_classes (int): Number of classification classes.\n depths (Tuple[int, int, int, int]): Number of blocks in each stage.\n num_layers (int): Total number of layers in the network.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n patch_embed (PatchEmbed): Module for patch embedding.\n patches_resolution (Tuple[int, int]): Resolution of embedded patches.\n layers (nn.ModuleList): List of network layers.\n norm_head (nn.LayerNorm): Layer normalization for the classifier head.\n head (nn.Linear): Linear layer for final classification.\n neck (nn.Sequential): Neck module for feature refinement.\n\nExamples:\n >>> model = TinyViT(img_size=224, num_classes=1000)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> features = model.forward_features(x)\n >>> print(features.shape)\n torch.Size([1, 256, 56, 56])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "ultralytics.nn.modules.LayerNorm2d", "ultralytics.utils.instance.to_2tuple", "nn.Module" ], "chunk_id": "class_TinyViT_bdde4483" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_1a0c7fce" }, { "content": "from typing import Tuple, Type", "chunk_type": "import", "name": "Tuple, Type", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Tuple, Type_20611453" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_ed6be518" }, { "content": "from torch import Tensor, nn", "chunk_type": "import", "name": "Tensor, nn", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Tensor, nn_ba6ccb03" }, { "content": "from ultralytics.nn.modules import MLPBlock", "chunk_type": "import", "name": "MLPBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MLPBlock_8903c4ea" }, { "content": "class TwoWayTransformer(nn.Module):\n \"\"\"\n A Two-Way 
Transformer module for simultaneous attention to image and query points.\n\n This class implements a specialized transformer decoder that attends to an input image using queries with\n supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point\n cloud processing.\n\n Attributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\n Methods:\n forward: Process image and point embeddings through the transformer.\n\n Examples:\n >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 32, 32)\n >>> image_pe = torch.randn(1, 256, 32, 32)\n >>> point_embedding = torch.randn(1, 100, 256)\n >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)\n >>> print(output_queries.shape, output_image.shape)\n \"\"\"\n\n def __init__(\n self,\n depth: int,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n ) -> None:\n \"\"\"\n Initialize a Two-Way Transformer for simultaneous attention to image and query points.\n\n Args:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention. 
Must divide embedding_dim.\n mlp_dim (int): Internal channel dimension for the MLP block.\n activation (Type[nn.Module], optional): Activation function to use in the MLP block.\n attention_downsample_rate (int, optional): Downsampling rate for attention mechanism.\n \"\"\"\n super().__init__()\n self.depth = depth\n self.embedding_dim = embedding_dim\n self.num_heads = num_heads\n self.mlp_dim = mlp_dim\n self.layers = nn.ModuleList()\n\n for i in range(depth):\n self.layers.append(\n TwoWayAttentionBlock(\n embedding_dim=embedding_dim,\n num_heads=num_heads,\n mlp_dim=mlp_dim,\n activation=activation,\n attention_downsample_rate=attention_downsample_rate,\n skip_first_layer_pe=(i == 0),\n )\n )\n\n self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n self.norm_final_attn = nn.LayerNorm(embedding_dim)\n\n def forward(\n self,\n image_embedding: torch.Tensor,\n image_pe: torch.Tensor,\n point_embedding: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Process image and point embeddings through the Two-Way Transformer.\n\n Args:\n image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).\n image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.\n point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).\n\n Returns:\n queries (torch.Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).\n keys (torch.Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).\n \"\"\"\n # BxCxHxW -> BxHWxC == B x N_image_tokens x C\n image_embedding = image_embedding.flatten(2).permute(0, 2, 1)\n image_pe = image_pe.flatten(2).permute(0, 2, 1)\n\n # Prepare queries\n queries = point_embedding\n keys = image_embedding\n\n # Apply transformer blocks and final layernorm\n for layer in self.layers:\n queries, keys = layer(\n queries=queries,\n keys=keys,\n query_pe=point_embedding,\n key_pe=image_pe,\n )\n\n # Apply the final attention layer from the points to the image\n q = queries + point_embedding\n k = keys + image_pe\n attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)\n queries = queries + attn_out\n queries = self.norm_final_attn(queries)\n\n return queries, keys", "chunk_type": "class", "name": "TwoWayTransformer", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 12, "end_line": 125, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "A Two-Way Transformer module for simultaneous attention to image and query points.\n\nThis class implements a specialized transformer decoder that attends to an input image using queries with\nsupplied positional embeddings. 
It's useful for tasks like object detection, image segmentation, and point\ncloud processing.\n\nAttributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\nMethods:\n forward: Process image and point embeddings through the transformer.\n\nExamples:\n >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 32, 32)\n >>> image_pe = torch.randn(1, 256, 32, 32)\n >>> point_embedding = torch.randn(1, 100, 256)\n >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)\n >>> print(output_queries.shape, output_image.shape)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.Tuple", "typing.Type", "torch", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLPBlock", "nn.Module" ], "chunk_id": "class_TwoWayTransformer_15517ad5" }, { "content": "class TwoWayAttentionBlock(nn.Module):\n \"\"\"\n A two-way attention block for simultaneous attention to image and query points.\n\n This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,\n cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense\n inputs to sparse inputs.\n\n Attributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.\n mlp (MLPBlock): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after MLP block.\n norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.\n\n Methods:\n forward: Apply self-attention and cross-attention to queries and keys.\n\n Examples:\n >>> embedding_dim, num_heads = 256, 8\n >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)\n >>> queries = torch.randn(1, 100, embedding_dim)\n >>> keys = torch.randn(1, 1000, embedding_dim)\n >>> query_pe = torch.randn(1, 100, embedding_dim)\n >>> key_pe = torch.randn(1, 1000, embedding_dim)\n >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)\n \"\"\"\n\n def __init__(\n self,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int = 2048,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n skip_first_layer_pe: bool = False,\n ) -> None:\n \"\"\"\n Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.\n\n This block implements a specialized transformer layer with four main components: self-attention on sparse\n inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention\n of dense inputs to sparse inputs.\n\n Args:\n embedding_dim (int): Channel dimension of the 
embeddings.\n num_heads (int): Number of attention heads in the attention layers.\n mlp_dim (int, optional): Hidden dimension of the MLP block.\n activation (Type[nn.Module], optional): Activation function for the MLP block.\n attention_downsample_rate (int, optional): Downsampling rate for the attention mechanism.\n skip_first_layer_pe (bool, optional): Whether to skip positional encoding in the first layer.\n \"\"\"\n super().__init__()\n self.self_attn = Attention(embedding_dim, num_heads)\n self.norm1 = nn.LayerNorm(embedding_dim)\n\n self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n self.norm2 = nn.LayerNorm(embedding_dim)\n\n self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)\n self.norm3 = nn.LayerNorm(embedding_dim)\n\n self.norm4 = nn.LayerNorm(embedding_dim)\n self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n\n self.skip_first_layer_pe = skip_first_layer_pe\n\n def forward(\n self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Apply two-way attention to process query and key embeddings in a transformer block.\n\n Args:\n queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).\n keys (torch.Tensor): Key embeddings with shape (B, N_keys, embedding_dim).\n query_pe (torch.Tensor): Positional encodings for queries with same shape as queries.\n key_pe (torch.Tensor): Positional encodings for keys with same shape as keys.\n\n Returns:\n queries (torch.Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).\n keys (torch.Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).\n \"\"\"\n # Self attention block\n if self.skip_first_layer_pe:\n queries = self.self_attn(q=queries, k=queries, v=queries)\n else:\n q = queries + query_pe\n attn_out = self.self_attn(q=q, k=q, v=queries)\n queries = queries + attn_out\n queries = self.norm1(queries)\n\n # Cross attention block, tokens attending to image embedding\n q = queries + query_pe\n k = keys + key_pe\n attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)\n queries = queries + attn_out\n queries = self.norm2(queries)\n\n # MLP block\n mlp_out = self.mlp(queries)\n queries = queries + mlp_out\n queries = self.norm3(queries)\n\n # Cross attention block, image embedding attending to tokens\n q = queries + query_pe\n k = keys + key_pe\n attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)\n keys = keys + attn_out\n keys = self.norm4(keys)\n\n return queries, keys", "chunk_type": "class", "name": "TwoWayAttentionBlock", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 128, "end_line": 243, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "A two-way attention block for simultaneous attention to image and query points.\n\nThis class implements a specialized transformer block with four main layers: self-attention on sparse inputs,\ncross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense\ninputs to sparse inputs.\n\nAttributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.\n mlp (MLPBlock): MLP 
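TwoWayAttentionBlock above alternates "tokens attend to the image" with "the image attends back to the tokens", adding the positional encodings to queries and keys but never to the values. As a rough, self-contained illustration of that round trip, the snippet below uses torch's built-in nn.MultiheadAttention as a stand-in for the repository's downsampling Attention class; the shapes follow the docstring example, and everything else here is an assumption for demonstration only.

```python
import torch
import torch.nn as nn

dim, heads = 256, 8
tok_to_img = nn.MultiheadAttention(dim, heads, batch_first=True)   # stand-in cross-attention layers
img_to_tok = nn.MultiheadAttention(dim, heads, batch_first=True)

queries = torch.randn(1, 100, dim)   # sparse point tokens
keys = torch.randn(1, 1000, dim)     # flattened image tokens
query_pe = torch.randn(1, 100, dim)
key_pe = torch.randn(1, 1000, dim)

# Tokens gather image context: positional encodings go into q/k only, values stay "clean".
attn_out, _ = tok_to_img(query=queries + query_pe, key=keys + key_pe, value=keys)
queries = queries + attn_out

# The image gathers token context with the roles of q and k swapped.
attn_out, _ = img_to_tok(query=keys + key_pe, key=queries + query_pe, value=queries)
keys = keys + attn_out
print(queries.shape, keys.shape)     # torch.Size([1, 100, 256]) torch.Size([1, 1000, 256])
```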
block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after MLP block.\n norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.\n\nMethods:\n forward: Apply self-attention and cross-attention to queries and keys.\n\nExamples:\n >>> embedding_dim, num_heads = 256, 8\n >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)\n >>> queries = torch.randn(1, 100, embedding_dim)\n >>> keys = torch.randn(1, 1000, embedding_dim)\n >>> query_pe = torch.randn(1, 100, embedding_dim)\n >>> key_pe = torch.randn(1, 1000, embedding_dim)\n >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.Tuple", "typing.Type", "torch", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLPBlock", "nn.Module" ], "chunk_id": "class_TwoWayAttentionBlock_851bcd69" }, { "content": "class Attention(nn.Module):\n \"\"\"\n An attention layer with downscaling capability for embedding size after projection.\n\n This class implements a multi-head attention mechanism with the option to downsample the internal\n dimension of queries, keys, and values.\n\n Attributes:\n embedding_dim (int): Dimensionality of input embeddings.\n kv_in_dim (int): Dimensionality of key and value inputs.\n internal_dim (int): Internal dimension after downsampling.\n num_heads (int): Number of attention heads.\n q_proj (nn.Linear): Linear projection for queries.\n k_proj (nn.Linear): Linear projection for keys.\n v_proj (nn.Linear): Linear projection for values.\n out_proj (nn.Linear): Linear projection for output.\n\n Methods:\n _separate_heads: Separate input tensor into attention heads.\n _recombine_heads: Recombine separated attention heads.\n forward: Compute attention output for given query, key, and value tensors.\n\n Examples:\n >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)\n >>> q = torch.randn(1, 100, 256)\n >>> k = v = torch.randn(1, 50, 256)\n >>> output = attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 100, 256])\n \"\"\"\n\n def __init__(\n self,\n embedding_dim: int,\n num_heads: int,\n downsample_rate: int = 1,\n kv_in_dim: int = None,\n ) -> None:\n \"\"\"\n Initialize the Attention module with specified dimensions and settings.\n\n Args:\n embedding_dim (int): Dimensionality of input embeddings.\n num_heads (int): Number of attention heads.\n downsample_rate (int, optional): Factor by which internal dimensions are downsampled.\n kv_in_dim (int | None, optional): Dimensionality of key and value inputs. 
If None, uses embedding_dim.\n\n Raises:\n AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).\n \"\"\"\n super().__init__()\n self.embedding_dim = embedding_dim\n self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim\n self.internal_dim = embedding_dim // downsample_rate\n self.num_heads = num_heads\n assert self.internal_dim % num_heads == 0, \"num_heads must divide embedding_dim.\"\n\n self.q_proj = nn.Linear(embedding_dim, self.internal_dim)\n self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)\n self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)\n self.out_proj = nn.Linear(self.internal_dim, embedding_dim)\n\n @staticmethod\n def _separate_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:\n \"\"\"Separate the input tensor into the specified number of attention heads.\"\"\"\n b, n, c = x.shape\n x = x.reshape(b, n, num_heads, c // num_heads)\n return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head\n\n @staticmethod\n def _recombine_heads(x: Tensor) -> Tensor:\n \"\"\"Recombine separated attention heads into a single tensor.\"\"\"\n b, n_heads, n_tokens, c_per_head = x.shape\n x = x.transpose(1, 2)\n return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C\n\n def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply multi-head attention to query, key, and value tensors with optional downsampling.\n\n Args:\n q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).\n k (torch.Tensor): Key tensor with shape (B, N_k, embedding_dim).\n v (torch.Tensor): Value tensor with shape (B, N_k, embedding_dim).\n\n Returns:\n (torch.Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).\n \"\"\"\n # Input projections\n q = self.q_proj(q)\n k = self.k_proj(k)\n v = self.v_proj(v)\n\n # Separate into heads\n q = self._separate_heads(q, self.num_heads)\n k = self._separate_heads(k, self.num_heads)\n v = self._separate_heads(v, self.num_heads)\n\n # Attention\n _, _, _, c_per_head = q.shape\n attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens\n attn = attn / math.sqrt(c_per_head)\n attn = torch.softmax(attn, dim=-1)\n\n # Get output\n out = attn @ v\n out = self._recombine_heads(out)\n return self.out_proj(out)", "chunk_type": "class", "name": "Attention", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py", "start_line": 246, "end_line": 353, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "An attention layer with downscaling capability for embedding size after projection.\n\nThis class implements a multi-head attention mechanism with the option to downsample the internal\ndimension of queries, keys, and values.\n\nAttributes:\n embedding_dim (int): Dimensionality of input embeddings.\n kv_in_dim (int): Dimensionality of key and value inputs.\n internal_dim (int): Internal dimension after downsampling.\n num_heads (int): Number of attention heads.\n q_proj (nn.Linear): Linear projection for queries.\n k_proj (nn.Linear): Linear projection for keys.\n v_proj (nn.Linear): Linear projection for values.\n out_proj (nn.Linear): Linear projection for output.\n\nMethods:\n _separate_heads: Separate input tensor into attention heads.\n _recombine_heads: Recombine separated attention heads.\n forward: Compute attention output for given query, key, and value tensors.\n\nExamples:\n >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)\n >>> q = 
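The _separate_heads / _recombine_heads helpers in the Attention record above are pure reshape operations. The check below shows the shapes involved and that the pair is lossless; the dimensions match the docstring example but are otherwise arbitrary.

```python
import torch

b, n, c, num_heads = 2, 100, 256, 8
x = torch.randn(b, n, c)

heads = x.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)   # (B, N_heads, N_tokens, C_per_head)
print(heads.shape)                                                   # torch.Size([2, 8, 100, 32])

merged = heads.transpose(1, 2).reshape(b, n, c)                      # back to (B, N_tokens, C)
print(torch.equal(merged, x))                                        # True: splitting heads loses nothing
```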
torch.randn(1, 100, 256)\n >>> k = v = torch.randn(1, 50, 256)\n >>> output = attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 100, 256])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.Tuple", "typing.Type", "torch", "torch.Tensor", "torch.nn", "ultralytics.nn.modules.MLPBlock", "nn.Module" ], "chunk_id": "class_Attention_d0036d25" }, { "content": "from typing import Any, Dict, Tuple", "chunk_type": "import", "name": "Any, Dict, Tuple", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Tuple_80431a72" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_7f36154e" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_659e6c2d" }, { "content": "def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: Dict[int, Any], max_cond_frame_num: int):\n \"\"\"\n Select the closest conditioning frames to a given frame index.\n\n Args:\n frame_idx (int): Current frame index.\n cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.\n max_cond_frame_num (int): Maximum number of conditioning frames to select.\n\n Returns:\n selected_outputs (Dict[int, Any]): Selected items from cond_frame_outputs.\n unselected_outputs (Dict[int, Any]): Items not selected from cond_frame_outputs.\n\n Examples:\n >>> frame_idx = 5\n >>> cond_frame_outputs = {1: \"a\", 3: \"b\", 7: \"c\", 9: \"d\"}\n >>> max_cond_frame_num = 2\n >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)\n >>> print(selected)\n {3: 'b', 7: 'c'}\n >>> print(unselected)\n {1: 'a', 9: 'd'}\n \"\"\"\n if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:\n selected_outputs = cond_frame_outputs\n unselected_outputs = {}\n else:\n assert max_cond_frame_num >= 2, \"we should allow using 2+ conditioning frames\"\n selected_outputs = {}\n\n # The closest conditioning frame before `frame_idx` (if any)\n idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)\n if idx_before is not None:\n selected_outputs[idx_before] = cond_frame_outputs[idx_before]\n\n # The closest conditioning frame after `frame_idx` (if any)\n idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)\n if idx_after is not None:\n selected_outputs[idx_after] = cond_frame_outputs[idx_after]\n\n # Add other temporally closest conditioning frames until reaching a total\n # of `max_cond_frame_num` conditioning frames.\n num_remain = max_cond_frame_num - len(selected_outputs)\n 
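A minimal usage sketch for the Attention class shown in the preceding chunk, assuming the import path given in its file_path metadata (ultralytics/models/sam/modules/transformer.py). It is illustrative only and exercises the downsample_rate and kv_in_dim options described in the docstring; the shapes and assertions follow directly from the chunk's __init__ and forward code.

```python
# Illustrative sketch (not from the source): exercising the Attention chunk above
# to show how downsample_rate and kv_in_dim change the projection widths.
import torch

from ultralytics.models.sam.modules.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2, kv_in_dim=64)

# internal_dim = embedding_dim // downsample_rate = 128, split across 8 heads of 16 channels each.
assert attn.internal_dim == 128
assert attn.q_proj.out_features == 128 and attn.k_proj.in_features == 64

q = torch.randn(1, 100, 256)    # queries carry embedding_dim channels
k = v = torch.randn(1, 50, 64)  # keys/values carry kv_in_dim channels
out = attn(q, k, v)
print(out.shape)  # torch.Size([1, 100, 256]) -- out_proj maps back to embedding_dim
```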
inds_remain = sorted(\n (t for t in cond_frame_outputs if t not in selected_outputs),\n key=lambda x: abs(x - frame_idx),\n )[:num_remain]\n selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)\n unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}\n\n return selected_outputs, unselected_outputs", "chunk_type": "function", "name": "select_closest_cond_frames", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 9, "end_line": 59, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Select the closest conditioning frames to a given frame index.\n\nArgs:\n frame_idx (int): Current frame index.\n cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.\n max_cond_frame_num (int): Maximum number of conditioning frames to select.\n\nReturns:\n selected_outputs (Dict[int, Any]): Selected items from cond_frame_outputs.\n unselected_outputs (Dict[int, Any]): Items not selected from cond_frame_outputs.\n\nExamples:\n >>> frame_idx = 5\n >>> cond_frame_outputs = {1: \"a\", 3: \"b\", 7: \"c\", 9: \"d\"}\n >>> max_cond_frame_num = 2\n >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)\n >>> print(selected)\n {3: 'b', 7: 'c'}\n >>> print(unselected)\n {1: 'a', 9: 'd'}", "parameters": [ "frame_idx: int", "cond_frame_outputs: Dict[int, Any]", "max_cond_frame_num: int" ], "return_type": null, "decorators": [], "complexity_score": 9, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_select_closest_cond_frames_f98c91cb" }, { "content": "def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):\n \"\"\"\n Generate 1D sinusoidal positional embeddings for given positions and dimensions.\n\n Args:\n pos_inds (torch.Tensor): Position indices for which to generate embeddings.\n dim (int): Dimension of the positional embeddings. Should be an even number.\n temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.\n\n Returns:\n (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).\n\n Examples:\n >>> pos = torch.tensor([0, 1, 2, 3])\n >>> embeddings = get_1d_sine_pe(pos, 128)\n >>> embeddings.shape\n torch.Size([4, 128])\n \"\"\"\n pe_dim = dim // 2\n dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)\n dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)\n\n pos_embed = pos_inds.unsqueeze(-1) / dim_t\n pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)\n return pos_embed", "chunk_type": "function", "name": "get_1d_sine_pe", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 62, "end_line": 86, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Generate 1D sinusoidal positional embeddings for given positions and dimensions.\n\nArgs:\n pos_inds (torch.Tensor): Position indices for which to generate embeddings.\n dim (int): Dimension of the positional embeddings. 
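A small illustrative check of both branches of select_closest_cond_frames above, assuming the module path listed in the chunk metadata (ultralytics/models/sam/modules/utils.py): the pass-through case when max_cond_frame_num is -1, and the nearest-before/nearest-after selection otherwise.

```python
# Illustrative only; values mirror the docstring example above.
from ultralytics.models.sam.modules.utils import select_closest_cond_frames

cond = {1: "a", 3: "b", 7: "c", 9: "d"}

# -1 means "no limit": everything is selected, nothing is left over.
selected, unselected = select_closest_cond_frames(5, cond, max_cond_frame_num=-1)
assert selected == cond and unselected == {}

# With a limit of 2, the closest frame before and at/after frame_idx=5 are kept (3 and 7).
selected, unselected = select_closest_cond_frames(5, cond, max_cond_frame_num=2)
print(selected)    # {3: 'b', 7: 'c'}
print(unselected)  # {1: 'a', 9: 'd'}
```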
Should be an even number.\n temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.\n\nReturns:\n (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).\n\nExamples:\n >>> pos = torch.tensor([0, 1, 2, 3])\n >>> embeddings = get_1d_sine_pe(pos, 128)\n >>> embeddings.shape\n torch.Size([4, 128])", "parameters": [ "pos_inds: torch.Tensor", "dim: int", "temperature: float" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_get_1d_sine_pe_0de510ca" }, { "content": "def init_t_xy(end_x: int, end_y: int):\n \"\"\"\n Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.\n\n This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor\n and corresponding x and y coordinate tensors.\n\n Args:\n end_x (int): Width of the grid (number of columns).\n end_y (int): Height of the grid (number of rows).\n\n Returns:\n t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).\n t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).\n\n Examples:\n >>> t_x, t_y = init_t_xy(3, 2)\n >>> print(t_x)\n tensor([0., 1., 2., 0., 1., 2.])\n >>> print(t_y)\n tensor([0., 0., 0., 1., 1., 1.])\n \"\"\"\n t = torch.arange(end_x * end_y, dtype=torch.float32)\n t_x = (t % end_x).float()\n t_y = torch.div(t, end_x, rounding_mode=\"floor\").float()\n return t_x, t_y", "chunk_type": "function", "name": "init_t_xy", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 89, "end_line": 114, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.\n\nThis function creates coordinate tensors for a grid with dimensions end_x × end_y. 
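A quick structural check for get_1d_sine_pe above (illustrative; module path taken from the chunk metadata): the first dim//2 channels are sine terms and the last dim//2 are cosine terms, so position 0 maps to all zeros followed by all ones.

```python
# Illustrative sketch of the sin/cos layout produced by get_1d_sine_pe.
import torch

from ultralytics.models.sam.modules.utils import get_1d_sine_pe

pos = torch.arange(4, dtype=torch.float32)
pe = get_1d_sine_pe(pos, dim=8)
print(pe.shape)  # torch.Size([4, 8]) -> first 4 channels sin, last 4 channels cos

# Position 0 gives sin(0)=0 in the first half and cos(0)=1 in the second half.
assert torch.allclose(pe[0, :4], torch.zeros(4))
assert torch.allclose(pe[0, 4:], torch.ones(4))
```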
It generates a linear index tensor\nand corresponding x and y coordinate tensors.\n\nArgs:\n end_x (int): Width of the grid (number of columns).\n end_y (int): Height of the grid (number of rows).\n\nReturns:\n t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).\n t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).\n\nExamples:\n >>> t_x, t_y = init_t_xy(3, 2)\n >>> print(t_x)\n tensor([0., 1., 2., 0., 1., 2.])\n >>> print(t_y)\n tensor([0., 0., 0., 1., 1., 1.])", "parameters": [ "end_x: int", "end_y: int" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_init_t_xy_c89e4fba" }, { "content": "def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):\n \"\"\"\n Compute axial complex exponential positional encodings for 2D spatial positions in a grid.\n\n This function generates complex exponential positional encodings for a 2D grid of spatial positions,\n using separate frequency components for the x and y dimensions.\n\n Args:\n dim (int): Dimension of the positional encoding.\n end_x (int): Width of the 2D grid.\n end_y (int): Height of the 2D grid.\n theta (float, optional): Scaling factor for frequency computation.\n\n Returns:\n (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).\n\n Examples:\n >>> dim, end_x, end_y = 128, 8, 8\n >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)\n >>> freqs_cis.shape\n torch.Size([64, 64])\n \"\"\"\n freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))\n freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))\n\n t_x, t_y = init_t_xy(end_x, end_y)\n freqs_x = torch.outer(t_x, freqs_x)\n freqs_y = torch.outer(t_y, freqs_y)\n freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)\n freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)\n return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)", "chunk_type": "function", "name": "compute_axial_cis", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 117, "end_line": 147, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "Compute axial complex exponential positional encodings for 2D spatial positions in a grid.\n\nThis function generates complex exponential positional encodings for a 2D grid of spatial positions,\nusing separate frequency components for the x and y dimensions.\n\nArgs:\n dim (int): Dimension of the positional encoding.\n end_x (int): Width of the 2D grid.\n end_y (int): Height of the 2D grid.\n theta (float, optional): Scaling factor for frequency computation.\n\nReturns:\n (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).\n\nExamples:\n >>> dim, end_x, end_y = 128, 8, 8\n >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)\n >>> freqs_cis.shape\n torch.Size([64, 64])", "parameters": [ "dim: int", "end_x: int", "end_y: int", "theta: float" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_compute_axial_cis_723b9554" }, { "content": "def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n \"\"\"\n Reshape frequency tensor for broadcasting with input tensor.\n\n Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting 
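An illustrative check of the axial RoPE table built by compute_axial_cis above (module path per the chunk metadata; the grid size and dim are arbitrary): dim//4 complex frequencies per axis are concatenated into dim//2 unit-magnitude complex entries per grid cell.

```python
# Illustrative sketch; numbers are placeholders.
import torch

from ultralytics.models.sam.modules.utils import compute_axial_cis

dim, end_x, end_y = 64, 4, 4
freqs_cis = compute_axial_cis(dim, end_x, end_y)

print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([16, 32]) torch.complex64
# torch.polar with unit magnitude puts every entry on the complex unit circle.
assert torch.allclose(freqs_cis.abs(), torch.ones_like(freqs_cis.abs()))
```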
with an input tensor.\n This function is typically used in positional encoding operations.\n\n Args:\n freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.\n x (torch.Tensor): Input tensor to broadcast with.\n\n Returns:\n (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.\n\n Raises:\n AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.\n \"\"\"\n ndim = x.ndim\n assert 0 <= 1 < ndim\n assert freqs_cis.shape == (x.shape[-2], x.shape[-1])\n shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]\n return freqs_cis.view(*shape)", "chunk_type": "function", "name": "reshape_for_broadcast", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 150, "end_line": 171, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": "Reshape frequency tensor for broadcasting with input tensor.\n\nReshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.\nThis function is typically used in positional encoding operations.\n\nArgs:\n freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.\n x (torch.Tensor): Input tensor to broadcast with.\n\nReturns:\n (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.\n\nRaises:\n AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.", "parameters": [ "freqs_cis: torch.Tensor", "x: torch.Tensor" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_reshape_for_broadcast_c523e5cb" }, { "content": "def apply_rotary_enc(\n xq: torch.Tensor,\n xk: torch.Tensor,\n freqs_cis: torch.Tensor,\n repeat_freqs_k: bool = False,\n):\n \"\"\"\n Apply rotary positional encoding to query and key tensors.\n\n This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency\n components. 
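A shape-only sketch of reshape_for_broadcast above (illustrative; in real use the frequency tensor is complex, but the reshape itself is dtype-agnostic): the table matching the last two dims of the input is viewed with leading singleton dims so it broadcasts over batch and heads.

```python
# Shape-only sketch (illustrative).
import torch

from ultralytics.models.sam.modules.utils import reshape_for_broadcast

x = torch.randn(2, 8, 16, 32)   # e.g. (batch, heads, seq_len, dim) as seen by the rotary step
freqs = torch.randn(16, 32)     # must match x.shape[-2:]
print(reshape_for_broadcast(freqs, x).shape)  # torch.Size([1, 1, 16, 32])
```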
RoPE is a technique that injects relative position information into self-attention mechanisms.\n\n Args:\n xq (torch.Tensor): Query tensor to encode with positional information.\n xk (torch.Tensor): Key tensor to encode with positional information.\n freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the\n last two dimensions of xq.\n repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension\n to match key sequence length.\n\n Returns:\n xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.\n xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.\n\n Examples:\n >>> import torch\n >>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]\n >>> xk = torch.randn(2, 8, 16, 64)\n >>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64\n >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)\n \"\"\"\n xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None\n freqs_cis = reshape_for_broadcast(freqs_cis, xq_)\n xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n if xk_ is None:\n # No keys to rotate, due to dropout\n return xq_out.type_as(xq).to(xq.device), xk\n # Repeat freqs along seq_len dim to match k seq_len\n if repeat_freqs_k:\n r = xk_.shape[-2] // xq_.shape[-2]\n freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)\n xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)", "chunk_type": "function", "name": "apply_rotary_enc", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 174, "end_line": 217, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": "Apply rotary positional encoding to query and key tensors.\n\nThis function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency\ncomponents. 
RoPE is a technique that injects relative position information into self-attention mechanisms.\n\nArgs:\n xq (torch.Tensor): Query tensor to encode with positional information.\n xk (torch.Tensor): Key tensor to encode with positional information.\n freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the\n last two dimensions of xq.\n repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension\n to match key sequence length.\n\nReturns:\n xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.\n xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.\n\nExamples:\n >>> import torch\n >>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]\n >>> xk = torch.randn(2, 8, 16, 64)\n >>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64\n >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)", "parameters": [ "xq: torch.Tensor", "xk: torch.Tensor", "freqs_cis: torch.Tensor", "repeat_freqs_k: bool" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_apply_rotary_enc_5a2215dd" }, { "content": "def window_partition(x: torch.Tensor, window_size: int):\n \"\"\"\n Partition input tensor into non-overlapping windows with padding if needed.\n\n Args:\n x (torch.Tensor): Input tensor with shape (B, H, W, C).\n window_size (int): Size of each window.\n\n Returns:\n windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).\n padded_h_w (Tuple[int, int]): Padded height and width before partition.\n\n Examples:\n >>> x = torch.randn(1, 16, 16, 3)\n >>> windows, (Hp, Wp) = window_partition(x, window_size=4)\n >>> print(windows.shape, Hp, Wp)\n torch.Size([16, 4, 4, 3]) 16 16\n \"\"\"\n B, H, W, C = x.shape\n\n pad_h = (window_size - H % window_size) % window_size\n pad_w = (window_size - W % window_size) % window_size\n if pad_h > 0 or pad_w > 0:\n x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))\n Hp, Wp = H + pad_h, W + pad_w\n\n x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)\n windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n return windows, (Hp, Wp)", "chunk_type": "function", "name": "window_partition", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 220, "end_line": 248, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Partition input tensor into non-overlapping windows with padding if needed.\n\nArgs:\n x (torch.Tensor): Input tensor with shape (B, H, W, C).\n window_size (int): Size of each window.\n\nReturns:\n windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).\n padded_h_w (Tuple[int, int]): Padded height and width before partition.\n\nExamples:\n >>> x = torch.randn(1, 16, 16, 3)\n >>> windows, (Hp, Wp) = window_partition(x, window_size=4)\n >>> print(windows.shape, Hp, Wp)\n torch.Size([16, 4, 4, 3]) 16 16", "parameters": [ "x: torch.Tensor", "window_size: int" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_window_partition_c19df743" }, { "content": "def window_unpartition(windows: torch.Tensor, window_size: 
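An end-to-end sketch combining compute_axial_cis and apply_rotary_enc from the chunks above (illustrative; shapes mirror the docstring example, with 16 tokens standing for a 4x4 spatial grid). Because the rotation multiplies by unit-magnitude complex factors, the overall tensor norm is preserved, which the final assertion checks numerically.

```python
# Illustrative RoPE pipeline sketch.
import torch

from ultralytics.models.sam.modules.utils import apply_rotary_enc, compute_axial_cis

B, heads, seq_len, head_dim = 2, 8, 16, 64   # 16 tokens == a 4x4 spatial grid
xq = torch.randn(B, heads, seq_len, head_dim)
xk = torch.randn(B, heads, seq_len, head_dim)

freqs_cis = compute_axial_cis(dim=head_dim, end_x=4, end_y=4)  # (16, 32) complex table
q_rot, k_rot = apply_rotary_enc(xq, xk, freqs_cis)

print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 16, 64]) twice
# Rotation by unit-magnitude complex factors leaves the norm unchanged (up to float error).
assert torch.allclose(q_rot.norm(), xq.norm(), atol=1e-3)
```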
int, pad_hw: Tuple[int, int], hw: Tuple[int, int]):\n \"\"\"\n Unpartition windowed sequences into original sequences and remove padding.\n\n This function reverses the windowing process, reconstructing the original input from windowed segments\n and removing any padding that was added during the windowing process.\n\n Args:\n windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,\n window_size, C), where B is the batch size, num_windows is the number of windows, window_size is\n the size of each window, and C is the number of channels.\n window_size (int): Size of each window.\n pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.\n hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.\n\n Returns:\n (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W\n are the original height and width, and C is the number of channels.\n\n Examples:\n >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels\n >>> pad_hw = (16, 16) # Padded height and width\n >>> hw = (15, 14) # Original height and width\n >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)\n >>> print(x.shape)\n torch.Size([1, 15, 14, 64])\n \"\"\"\n Hp, Wp = pad_hw\n H, W = hw\n B = windows.shape[0] // (Hp * Wp // window_size // window_size)\n x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)\n x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)\n\n if Hp > H or Wp > W:\n x = x[:, :H, :W, :].contiguous()\n return x", "chunk_type": "function", "name": "window_unpartition", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 251, "end_line": 286, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Unpartition windowed sequences into original sequences and remove padding.\n\nThis function reverses the windowing process, reconstructing the original input from windowed segments\nand removing any padding that was added during the windowing process.\n\nArgs:\n windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,\n window_size, C), where B is the batch size, num_windows is the number of windows, window_size is\n the size of each window, and C is the number of channels.\n window_size (int): Size of each window.\n pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.\n hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.\n\nReturns:\n (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W\n are the original height and width, and C is the number of channels.\n\nExamples:\n >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels\n >>> pad_hw = (16, 16) # Padded height and width\n >>> hw = (15, 14) # Original height and width\n >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)\n >>> print(x.shape)\n torch.Size([1, 15, 14, 64])", "parameters": [ "windows: torch.Tensor", "window_size: int", "pad_hw: Tuple[int, int]", "hw: Tuple[int, int]" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_window_unpartition_c242b26b" }, { "content": "def get_rel_pos(q_size: int, k_size: int, rel_pos: 
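A round-trip sketch for window_partition and window_unpartition above (illustrative): partitioning a tensor whose height and width are not multiples of the window size pads it, and unpartitioning with the recorded padded size strips the padding and recovers the original values exactly.

```python
# Illustrative round-trip check.
import torch

from ultralytics.models.sam.modules.utils import window_partition, window_unpartition

x = torch.randn(1, 15, 14, 64)                        # H=15, W=14 are not multiples of 8
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(windows.shape, (Hp, Wp))                        # torch.Size([4, 8, 8, 64]) (16, 16)

x_back = window_unpartition(windows, window_size=8, pad_hw=(Hp, Wp), hw=(15, 14))
assert torch.equal(x_back, x)                         # padding is stripped, values untouched
```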
torch.Tensor) -> torch.Tensor:\n \"\"\"\n Extract relative positional embeddings based on query and key sizes.\n\n Args:\n q_size (int): Size of the query.\n k_size (int): Size of the key.\n rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative\n distance and C is the embedding dimension.\n\n Returns:\n (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,\n k_size, C).\n\n Examples:\n >>> q_size, k_size = 8, 16\n >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1\n >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)\n >>> print(extracted_pos.shape)\n torch.Size([8, 16, 64])\n \"\"\"\n max_rel_dist = int(2 * max(q_size, k_size) - 1)\n # Interpolate rel pos if needed.\n if rel_pos.shape[0] != max_rel_dist:\n # Interpolate rel pos.\n rel_pos_resized = F.interpolate(\n rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),\n size=max_rel_dist,\n mode=\"linear\",\n )\n rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)\n else:\n rel_pos_resized = rel_pos\n\n # Scale the coords with short length if shapes for q and k are different.\n q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)\n k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)\n relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)\n\n return rel_pos_resized[relative_coords.long()]", "chunk_type": "function", "name": "get_rel_pos", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 289, "end_line": 328, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": "Extract relative positional embeddings based on query and key sizes.\n\nArgs:\n q_size (int): Size of the query.\n k_size (int): Size of the key.\n rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative\n distance and C is the embedding dimension.\n\nReturns:\n (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,\n k_size, C).\n\nExamples:\n >>> q_size, k_size = 8, 16\n >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1\n >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)\n >>> print(extracted_pos.shape)\n torch.Size([8, 16, 64])", "parameters": [ "q_size: int", "k_size: int", "rel_pos: torch.Tensor" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 2, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_get_rel_pos_ab6a89b4" }, { "content": "def add_decomposed_rel_pos(\n attn: torch.Tensor,\n q: torch.Tensor,\n rel_pos_h: torch.Tensor,\n rel_pos_w: torch.Tensor,\n q_size: Tuple[int, int],\n k_size: Tuple[int, int],\n) -> torch.Tensor:\n \"\"\"\n Add decomposed Relative Positional Embeddings to the attention map.\n\n This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2\n paper. 
It enhances the attention mechanism by incorporating spatial relationships between query and key\n positions.\n\n Args:\n attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).\n q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).\n rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).\n rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).\n q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).\n k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).\n\n Returns:\n (torch.Tensor): Updated attention map with added relative positional embeddings, shape\n (B, q_h * q_w, k_h * k_w).\n\n Examples:\n >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8\n >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)\n >>> q = torch.rand(B, q_h * q_w, C)\n >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)\n >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)\n >>> q_size, k_size = (q_h, q_w), (k_h, k_w)\n >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)\n >>> print(updated_attn.shape)\n torch.Size([1, 64, 64])\n\n References:\n https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py\n \"\"\"\n q_h, q_w = q_size\n k_h, k_w = k_size\n Rh = get_rel_pos(q_h, k_h, rel_pos_h)\n Rw = get_rel_pos(q_w, k_w, rel_pos_w)\n\n B, _, dim = q.shape\n r_q = q.reshape(B, q_h, q_w, dim)\n rel_h = torch.einsum(\"bhwc,hkc->bhwk\", r_q, Rh)\n rel_w = torch.einsum(\"bhwc,wkc->bhwk\", r_q, Rw)\n\n attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(\n B, q_h * q_w, k_h * k_w\n )\n\n return attn", "chunk_type": "function", "name": "add_decomposed_rel_pos", "file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py", "start_line": 331, "end_line": 386, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Add decomposed Relative Positional Embeddings to the attention map.\n\nThis function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2\npaper. 
It enhances the attention mechanism by incorporating spatial relationships between query and key\npositions.\n\nArgs:\n attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).\n q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).\n rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).\n rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).\n q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).\n k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).\n\nReturns:\n (torch.Tensor): Updated attention map with added relative positional embeddings, shape\n (B, q_h * q_w, k_h * k_w).\n\nExamples:\n >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8\n >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)\n >>> q = torch.rand(B, q_h * q_w, C)\n >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)\n >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)\n >>> q_size, k_size = (q_h, q_w), (k_h, k_w)\n >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)\n >>> print(updated_attn.shape)\n torch.Size([1, 64, 64])\n\nReferences:\n https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py", "parameters": [ "attn: torch.Tensor", "q: torch.Tensor", "rel_pos_h: torch.Tensor", "rel_pos_w: torch.Tensor", "q_size: Tuple[int, int]", "k_size: Tuple[int, int]" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 1, "dependencies": [ "typing.Any", "typing.Dict", "typing.Tuple", "torch", "torch.nn.functional" ], "chunk_id": "function_add_decomposed_rel_pos_ea656b85" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_0028d373" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_1499b24c" }, { "content": "from PIL import Image", "chunk_type": "import", "name": "Image", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Image_5a9b3a61" }, { "content": "from ultralytics.data.augment import classify_transforms", "chunk_type": "import", "name": "classify_transforms", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_classify_transforms_47decfd8" }, { "content": "from ultralytics.engine.predictor import BasePredictor", "chunk_type": "import", "name": "BasePredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", 
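An illustrative check of the interpolation branch in get_rel_pos above: when the stored relative-position table does not have 2 * max(q_size, k_size) - 1 rows, it is linearly interpolated to that length before indexing (module path per the chunk metadata; the table length below is deliberately wrong to trigger that branch).

```python
# Illustrative sketch of get_rel_pos with a table that needs interpolation.
import torch

from ultralytics.models.sam.modules.utils import get_rel_pos

q_size, k_size = 8, 16
rel_pos_small = torch.randn(13, 64)        # shorter than the required 2*16-1 = 31 rows
out = get_rel_pos(q_size, k_size, rel_pos_small)
print(out.shape)                           # torch.Size([8, 16, 64]) after interpolation to 31 rows
```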
"start_line": 8, "end_line": 8, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BasePredictor_637f1b05" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_b1491de7" }, { "content": "from ultralytics.utils import DEFAULT_CFG, ops", "chunk_type": "import", "name": "DEFAULT_CFG, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, ops_4cdc861c" }, { "content": "class ClassificationPredictor(BasePredictor):\n \"\"\"\n A class extending the BasePredictor class for prediction based on a classification model.\n\n This predictor handles the specific requirements of classification models, including preprocessing images\n and postprocessing predictions to generate classification results.\n\n Attributes:\n args (dict): Configuration arguments for the predictor.\n\n Methods:\n preprocess: Convert input images to model-compatible format.\n postprocess: Process model predictions into Results objects.\n\n Notes:\n - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.classify import ClassificationPredictor\n >>> args = dict(model=\"yolo11n-cls.pt\", source=ASSETS)\n >>> predictor = ClassificationPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.\n\n This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification\n tasks. 
It ensures the task is set to 'classify' regardless of input configuration.\n\n Args:\n cfg (dict): Default configuration dictionary containing prediction settings.\n overrides (dict, optional): Configuration overrides that take precedence over cfg.\n _callbacks (list, optional): List of callback functions to be executed during prediction.\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"classify\"\n\n def setup_source(self, source):\n \"\"\"Set up source and inference mode and classify transforms.\"\"\"\n super().setup_source(source)\n updated = (\n self.model.model.transforms.transforms[0].size != max(self.imgsz)\n if hasattr(self.model.model, \"transforms\") and hasattr(self.model.model.transforms.transforms[0], \"size\")\n else False\n )\n self.transforms = (\n classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms\n )\n\n def preprocess(self, img):\n \"\"\"Convert input images to model-compatible tensor format with appropriate normalization.\"\"\"\n if not isinstance(img, torch.Tensor):\n img = torch.stack(\n [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0\n )\n img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)\n return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Process predictions to return Results objects with classification probabilities.\n\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n img (torch.Tensor): Input images after preprocessing.\n orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing.\n\n Returns:\n (List[Results]): List of Results objects containing classification results for each image.\n \"\"\"\n if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n preds = preds[0] if isinstance(preds, (list, tuple)) else preds\n return [\n Results(orig_img, path=img_path, names=self.model.names, probs=pred)\n for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])\n ]", "chunk_type": "class", "name": "ClassificationPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py", "start_line": 13, "end_line": 93, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class extending the BasePredictor class for prediction based on a classification model.\n\nThis predictor handles the specific requirements of classification models, including preprocessing images\nand postprocessing predictions to generate classification results.\n\nAttributes:\n args (dict): Configuration arguments for the predictor.\n\nMethods:\n preprocess: Convert input images to model-compatible format.\n postprocess: Process model predictions into Results objects.\n\nNotes:\n - Torchvision classification models can also be passed to the 'model' argument, i.e. 
model='resnet18'.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.classify import ClassificationPredictor\n >>> args = dict(model=\"yolo11n-cls.pt\", source=ASSETS)\n >>> predictor = ClassificationPredictor(overrides=args)\n >>> predictor.predict_cli()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "cv2", "torch", "PIL.Image", "ultralytics.data.augment.classify_transforms", "ultralytics.engine.predictor.BasePredictor", "ultralytics.engine.results.Results", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.ops", "BasePredictor" ], "chunk_id": "class_ClassificationPredictor_2e4f6b41" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_6f9c54a3" }, { "content": "from typing import Any, Dict, Optional", "chunk_type": "import", "name": "Any, Dict, Optional", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Optional_433400f4" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_d70f3532" }, { "content": "from ultralytics.data import ClassificationDataset, build_dataloader", "chunk_type": "import", "name": "ClassificationDataset, build_dataloader", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationDataset, build_dataloader_2e41d386" }, { "content": "from ultralytics.engine.trainer import BaseTrainer", "chunk_type": "import", "name": "BaseTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseTrainer_3e625bf0" }, { "content": "from ultralytics.models import yolo", "chunk_type": "import", "name": "yolo", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yolo_08fdf44e" }, { "content": "from ultralytics.nn.tasks import ClassificationModel", "chunk_type": "import", "name": "ClassificationModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 11, 
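A standalone sketch of the preprocess path shown in the ClassificationPredictor chunk above: a BGR numpy frame is converted to an RGB PIL image, passed through the classify transforms, and stacked into a float batch. This is illustrative only, assumes classify_transforms' default resize/center-crop to the requested size, and does not construct a predictor.

```python
# Illustrative preprocess sketch (no predictor instance involved).
import cv2
import numpy as np
import torch
from PIL import Image

from ultralytics.data.augment import classify_transforms

transforms = classify_transforms(224)  # same call pattern used in setup_source above
frame_bgr = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

tensor = transforms(Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)))
batch = torch.stack([tensor], dim=0).float()  # fp32 path when the model is not fp16
print(batch.shape)                            # torch.Size([1, 3, 224, 224])
```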
"end_line": 11, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationModel_f3fdd470" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, RANK_b6bdd43b" }, { "content": "from ultralytics.utils.plotting import plot_images, plot_results", "chunk_type": "import", "name": "plot_images, plot_results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_plot_images, plot_results_5e0ad354" }, { "content": "from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first", "chunk_type": "import", "name": "is_parallel, strip_optimizer, torch_distributed_zero_first", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 100, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_is_parallel, strip_optimizer, torch_distributed_zero_first_058ff9e3" }, { "content": "class ClassificationTrainer(BaseTrainer):\n \"\"\"\n A trainer class extending BaseTrainer for training image classification models.\n\n This trainer handles the training process for image classification tasks, supporting both YOLO classification models\n and torchvision models with comprehensive dataset handling and validation.\n\n Attributes:\n model (ClassificationModel): The classification model to be trained.\n data (Dict[str, Any]): Dictionary containing dataset information including class names and number of classes.\n loss_names (List[str]): Names of the loss functions used during training.\n validator (ClassificationValidator): Validator instance for model evaluation.\n\n Methods:\n set_model_attributes: Set the model's class names from the loaded dataset.\n get_model: Return a modified PyTorch model configured for training.\n setup_model: Load, create or download model for classification.\n build_dataset: Create a ClassificationDataset instance.\n get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.\n preprocess_batch: Preprocess a batch of images and classes.\n progress_string: Return a formatted string showing training progress.\n get_validator: Return an instance of ClassificationValidator.\n label_loss_items: Return a loss dict with labelled training loss items.\n plot_metrics: Plot metrics from a CSV file.\n final_eval: Evaluate trained model and save validation results.\n plot_training_samples: Plot training samples with their annotations.\n\n Examples:\n Initialize and train a classification model\n >>> from ultralytics.models.yolo.classify import ClassificationTrainer\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\", epochs=3)\n >>> 
trainer = ClassificationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):\n \"\"\"\n Initialize a ClassificationTrainer object.\n\n This constructor sets up a trainer for image classification tasks, configuring the task type and default\n image size if not specified.\n\n Args:\n cfg (Dict[str, Any], optional): Default configuration dictionary containing training parameters.\n overrides (Dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (List[Any], optional): List of callback functions to be executed during training.\n\n Examples:\n Create a trainer with custom configuration\n >>> from ultralytics.models.yolo.classify import ClassificationTrainer\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\", epochs=3)\n >>> trainer = ClassificationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"classify\"\n if overrides.get(\"imgsz\") is None:\n overrides[\"imgsz\"] = 224\n super().__init__(cfg, overrides, _callbacks)\n\n def set_model_attributes(self):\n \"\"\"Set the YOLO model's class names from the loaded dataset.\"\"\"\n self.model.names = self.data[\"names\"]\n\n def get_model(self, cfg=None, weights=None, verbose: bool = True):\n \"\"\"\n Return a modified PyTorch model configured for training YOLO classification.\n\n Args:\n cfg (Any, optional): Model configuration.\n weights (Any, optional): Pre-trained model weights.\n verbose (bool, optional): Whether to display model information.\n\n Returns:\n (ClassificationModel): Configured PyTorch model for classification.\n \"\"\"\n model = ClassificationModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n\n for m in model.modules():\n if not self.args.pretrained and hasattr(m, \"reset_parameters\"):\n m.reset_parameters()\n if isinstance(m, torch.nn.Dropout) and self.args.dropout:\n m.p = self.args.dropout # set dropout\n for p in model.parameters():\n p.requires_grad = True # for training\n return model\n\n def setup_model(self):\n \"\"\"\n Load, create or download model for classification tasks.\n\n Returns:\n (Any): Model checkpoint if applicable, otherwise None.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n if str(self.model) in torchvision.models.__dict__:\n self.model = torchvision.models.__dict__[self.model](\n weights=\"IMAGENET1K_V1\" if self.args.pretrained else None\n )\n ckpt = None\n else:\n ckpt = super().setup_model()\n ClassificationModel.reshape_outputs(self.model, self.data[\"nc\"])\n return ckpt\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch=None):\n \"\"\"\n Create a ClassificationDataset instance given an image path and mode.\n\n Args:\n img_path (str): Path to the dataset images.\n mode (str, optional): Dataset mode ('train', 'val', or 'test').\n batch (Any, optional): Batch information (unused in this implementation).\n\n Returns:\n (ClassificationDataset): Dataset for the specified mode.\n \"\"\"\n return ClassificationDataset(root=img_path, args=self.args, augment=mode == \"train\", prefix=mode)\n\n def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = \"train\"):\n \"\"\"\n Return PyTorch DataLoader with transforms to preprocess images.\n\n Args:\n dataset_path (str): Path to the dataset.\n batch_size (int, 
optional): Number of images per batch.\n rank (int, optional): Process rank for distributed training.\n mode (str, optional): 'train', 'val', or 'test' mode.\n\n Returns:\n (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.\n \"\"\"\n with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP\n dataset = self.build_dataset(dataset_path, mode)\n\n loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)\n # Attach inference transforms\n if mode != \"train\":\n if is_parallel(self.model):\n self.model.module.transforms = loader.dataset.torch_transforms\n else:\n self.model.transforms = loader.dataset.torch_transforms\n return loader\n\n def preprocess_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n \"\"\"Preprocess a batch of images and classes.\"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device)\n batch[\"cls\"] = batch[\"cls\"].to(self.device)\n return batch\n\n def progress_string(self) -> str:\n \"\"\"Return a formatted string showing training progress.\"\"\"\n return (\"\\n\" + \"%11s\" * (4 + len(self.loss_names))) % (\n \"Epoch\",\n \"GPU_mem\",\n *self.loss_names,\n \"Instances\",\n \"Size\",\n )\n\n def get_validator(self):\n \"\"\"Return an instance of ClassificationValidator for validation.\"\"\"\n self.loss_names = [\"loss\"]\n return yolo.classify.ClassificationValidator(\n self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def label_loss_items(self, loss_items: Optional[torch.Tensor] = None, prefix: str = \"train\"):\n \"\"\"\n Return a loss dict with labelled training loss items tensor.\n\n Args:\n loss_items (torch.Tensor, optional): Loss tensor items.\n prefix (str, optional): Prefix to prepend to loss names.\n\n Returns:\n keys (List[str]): List of loss keys if loss_items is None.\n loss_dict (Dict[str, float]): Dictionary of loss items if loss_items is provided.\n \"\"\"\n keys = [f\"{prefix}/{x}\" for x in self.loss_names]\n if loss_items is None:\n return keys\n loss_items = [round(float(loss_items), 5)]\n return dict(zip(keys, loss_items))\n\n def plot_metrics(self):\n \"\"\"Plot metrics from a CSV file.\"\"\"\n plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png\n\n def final_eval(self):\n \"\"\"Evaluate trained model and save validation results.\"\"\"\n for f in self.last, self.best:\n if f.exists():\n strip_optimizer(f) # strip optimizers\n if f is self.best:\n LOGGER.info(f\"\\nValidating {f}...\")\n self.validator.args.data = self.args.data\n self.validator.args.plots = self.args.plots\n self.metrics = self.validator(model=f)\n self.metrics.pop(\"fitness\", None)\n self.run_callbacks(\"on_fit_epoch_end\")\n\n def plot_training_samples(self, batch: Dict[str, torch.Tensor], ni: int):\n \"\"\"\n Plot training samples with their annotations.\n\n Args:\n batch (Dict[str, torch.Tensor]): Batch containing images and class labels.\n ni (int): Number of iterations.\n \"\"\"\n batch[\"batch_idx\"] = torch.arange(len(batch[\"img\"])) # add batch index for plotting\n plot_images(\n labels=batch,\n fname=self.save_dir / f\"train_batch{ni}.jpg\",\n on_plot=self.on_plot,\n )", "chunk_type": "class", "name": "ClassificationTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py", "start_line": 17, "end_line": 236, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A trainer class extending BaseTrainer for training image classification models.\n\nThis trainer handles the 
training process for image classification tasks, supporting both YOLO classification models\nand torchvision models with comprehensive dataset handling and validation.\n\nAttributes:\n model (ClassificationModel): The classification model to be trained.\n data (Dict[str, Any]): Dictionary containing dataset information including class names and number of classes.\n loss_names (List[str]): Names of the loss functions used during training.\n validator (ClassificationValidator): Validator instance for model evaluation.\n\nMethods:\n set_model_attributes: Set the model's class names from the loaded dataset.\n get_model: Return a modified PyTorch model configured for training.\n setup_model: Load, create or download model for classification.\n build_dataset: Create a ClassificationDataset instance.\n get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.\n preprocess_batch: Preprocess a batch of images and classes.\n progress_string: Return a formatted string showing training progress.\n get_validator: Return an instance of ClassificationValidator.\n label_loss_items: Return a loss dict with labelled training loss items.\n plot_metrics: Plot metrics from a CSV file.\n final_eval: Evaluate trained model and save validation results.\n plot_training_samples: Plot training samples with their annotations.\n\nExamples:\n Initialize and train a classification model\n >>> from ultralytics.models.yolo.classify import ClassificationTrainer\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\", epochs=3)\n >>> trainer = ClassificationTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "typing.Any", "typing.Dict", "typing.Optional", "torch", "ultralytics.data.ClassificationDataset", "ultralytics.data.build_dataloader", "ultralytics.engine.trainer.BaseTrainer", "ultralytics.models.yolo", "ultralytics.nn.tasks.ClassificationModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.plotting.plot_images", "ultralytics.utils.plotting.plot_results", "ultralytics.utils.torch_utils.is_parallel", "ultralytics.utils.torch_utils.strip_optimizer", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "torchvision", "BaseTrainer" ], "chunk_id": "class_ClassificationTrainer_dd748521" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_3632a29c" }, { "content": "from typing import Any, Dict, List, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Tuple, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple, Union_63990ab4" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": 
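A usage sketch grounded in the setup_model method above: a bare torchvision model name such as "resnet18" is resolved through torchvision.models.__dict__ instead of a YOLO checkpoint. Illustrative only; the dataset name and epoch count are placeholder values taken from the docstring examples.

```python
# Illustrative: training a torchvision backbone through ClassificationTrainer.
from ultralytics.models.yolo.classify import ClassificationTrainer

args = dict(model="resnet18", data="imagenet10", epochs=1, imgsz=224)
trainer = ClassificationTrainer(overrides=args)
# trainer.train()  # resolves "resnet18" via torchvision; loads IMAGENET1K_V1 weights if pretrained
```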
null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_3f12a6f1" }, { "content": "from ultralytics.data import ClassificationDataset, build_dataloader", "chunk_type": "import", "name": "ClassificationDataset, build_dataloader", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationDataset, build_dataloader_5bab6936" }, { "content": "from ultralytics.engine.validator import BaseValidator", "chunk_type": "import", "name": "BaseValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseValidator_fbeb8d27" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_c5ddaab2" }, { "content": "from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix", "chunk_type": "import", "name": "ClassifyMetrics, ConfusionMatrix", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassifyMetrics, ConfusionMatrix_76591643" }, { "content": "from ultralytics.utils.plotting import plot_images", "chunk_type": "import", "name": "plot_images", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_plot_images_121c0ec7" }, { "content": "class ClassificationValidator(BaseValidator):\n \"\"\"\n A class extending the BaseValidator class for validation based on a classification model.\n\n This validator handles the validation process for classification models, including metrics calculation,\n confusion matrix generation, and visualization of results.\n\n Attributes:\n targets (List[torch.Tensor]): Ground truth class labels.\n pred (List[torch.Tensor]): Model predictions.\n metrics (ClassifyMetrics): Object to calculate and store classification metrics.\n names (dict): Mapping of class indices to class names.\n nc (int): Number of classes.\n confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.\n\n Methods:\n get_desc: Return a formatted string summarizing classification metrics.\n init_metrics: Initialize confusion matrix, class names, and tracking containers.\n preprocess: Preprocess input batch by moving data to device.\n update_metrics: Update running metrics with model predictions and batch targets.\n finalize_metrics: Finalize metrics 
including confusion matrix and processing speed.\n postprocess: Extract the primary prediction from model output.\n get_stats: Calculate and return a dictionary of metrics.\n build_dataset: Create a ClassificationDataset instance for validation.\n get_dataloader: Build and return a data loader for classification validation.\n print_results: Print evaluation metrics for the classification model.\n plot_val_samples: Plot validation image samples with their ground truth labels.\n plot_predictions: Plot images with their predicted class labels.\n\n Examples:\n >>> from ultralytics.models.yolo.classify import ClassificationValidator\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\")\n >>> validator = ClassificationValidator(args=args)\n >>> validator()\n\n Notes:\n Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize ClassificationValidator with dataloader, save directory, and other parameters.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.\n save_dir (str | Path, optional): Directory to save results.\n args (dict, optional): Arguments containing model and validation configuration.\n _callbacks (list, optional): List of callback functions to be called during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.classify import ClassificationValidator\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\")\n >>> validator = ClassificationValidator(args=args)\n >>> validator()\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.targets = None\n self.pred = None\n self.args.task = \"classify\"\n self.metrics = ClassifyMetrics()\n\n def get_desc(self) -> str:\n \"\"\"Return a formatted string summarizing classification metrics.\"\"\"\n return (\"%22s\" + \"%11s\" * 2) % (\"classes\", \"top1_acc\", \"top5_acc\")\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"Initialize confusion matrix, class names, and tracking containers for predictions and targets.\"\"\"\n self.names = model.names\n self.nc = len(model.names)\n self.pred = []\n self.targets = []\n self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess input batch by moving data to device and converting to appropriate dtype.\"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device, non_blocking=True)\n batch[\"img\"] = batch[\"img\"].half() if self.args.half else batch[\"img\"].float()\n batch[\"cls\"] = batch[\"cls\"].to(self.device)\n return batch\n\n def update_metrics(self, preds: torch.Tensor, batch: Dict[str, Any]) -> None:\n \"\"\"\n Update running metrics with model predictions and batch targets.\n\n Args:\n preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.\n batch (dict): Batch data containing images and class labels.\n\n Notes:\n This method appends the top-N predictions (sorted by confidence in descending order) to the\n prediction list for later evaluation. 
N is limited to the minimum of 5 and the number of classes.\n \"\"\"\n n5 = min(len(self.names), 5)\n self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())\n self.targets.append(batch[\"cls\"].type(torch.int32).cpu())\n\n def finalize_metrics(self) -> None:\n \"\"\"\n Finalize metrics including confusion matrix and processing speed.\n\n Notes:\n This method processes the accumulated predictions and targets to generate the confusion matrix,\n optionally plots it, and updates the metrics object with speed information.\n\n Examples:\n >>> validator = ClassificationValidator()\n >>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample\n >>> validator.targets = [torch.tensor([0])] # Ground truth class\n >>> validator.finalize_metrics()\n >>> print(validator.metrics.confusion_matrix) # Access the confusion matrix\n \"\"\"\n self.confusion_matrix.process_cls_preds(self.pred, self.targets)\n if self.args.plots:\n for normalize in True, False:\n self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)\n self.metrics.speed = self.speed\n self.metrics.save_dir = self.save_dir\n self.metrics.confusion_matrix = self.confusion_matrix\n\n def postprocess(self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> torch.Tensor:\n \"\"\"Extract the primary prediction from model output if it's in a list or tuple format.\"\"\"\n return preds[0] if isinstance(preds, (list, tuple)) else preds\n\n def get_stats(self) -> Dict[str, float]:\n \"\"\"Calculate and return a dictionary of metrics by processing targets and predictions.\"\"\"\n self.metrics.process(self.targets, self.pred)\n return self.metrics.results_dict\n\n def build_dataset(self, img_path: str) -> ClassificationDataset:\n \"\"\"Create a ClassificationDataset instance for validation.\"\"\"\n return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)\n\n def get_dataloader(self, dataset_path: Union[Path, str], batch_size: int) -> torch.utils.data.DataLoader:\n \"\"\"\n Build and return a data loader for classification validation.\n\n Args:\n dataset_path (str | Path): Path to the dataset directory.\n batch_size (int): Number of samples per batch.\n\n Returns:\n (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.\n \"\"\"\n dataset = self.build_dataset(dataset_path)\n return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)\n\n def print_results(self) -> None:\n \"\"\"Print evaluation metrics for the classification model.\"\"\"\n pf = \"%22s\" + \"%11.3g\" * len(self.metrics.keys) # print format\n LOGGER.info(pf % (\"all\", self.metrics.top1, self.metrics.top5))\n\n def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:\n \"\"\"\n Plot validation image samples with their ground truth labels.\n\n Args:\n batch (Dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).\n ni (int): Batch index used for naming the output file.\n\n Examples:\n >>> validator = ClassificationValidator()\n >>> batch = {\"img\": torch.rand(16, 3, 224, 224), \"cls\": torch.randint(0, 10, (16,))}\n >>> validator.plot_val_samples(batch, 0)\n \"\"\"\n batch[\"batch_idx\"] = torch.arange(len(batch[\"img\"])) # add batch index for plotting\n plot_images(\n labels=batch,\n fname=self.save_dir / f\"val_batch{ni}_labels.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n )\n\n def plot_predictions(self, batch: Dict[str, Any], 
preds: torch.Tensor, ni: int) -> None:\n \"\"\"\n Plot images with their predicted class labels and save the visualization.\n\n Args:\n batch (Dict[str, Any]): Batch data containing images and other information.\n preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).\n ni (int): Batch index used for naming the output file.\n\n Examples:\n >>> validator = ClassificationValidator()\n >>> batch = {\"img\": torch.rand(16, 3, 224, 224)}\n >>> preds = torch.rand(16, 10) # 16 images, 10 classes\n >>> validator.plot_predictions(batch, preds, 0)\n \"\"\"\n batched_preds = dict(\n img=batch[\"img\"],\n batch_idx=torch.arange(len(batch[\"img\"])),\n cls=torch.argmax(preds, dim=1),\n )\n plot_images(\n batched_preds,\n fname=self.save_dir / f\"val_batch{ni}_pred.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n ) # pred", "chunk_type": "class", "name": "ClassificationValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py", "start_line": 15, "end_line": 212, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class extending the BaseValidator class for validation based on a classification model.\n\nThis validator handles the validation process for classification models, including metrics calculation,\nconfusion matrix generation, and visualization of results.\n\nAttributes:\n targets (List[torch.Tensor]): Ground truth class labels.\n pred (List[torch.Tensor]): Model predictions.\n metrics (ClassifyMetrics): Object to calculate and store classification metrics.\n names (dict): Mapping of class indices to class names.\n nc (int): Number of classes.\n confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.\n\nMethods:\n get_desc: Return a formatted string summarizing classification metrics.\n init_metrics: Initialize confusion matrix, class names, and tracking containers.\n preprocess: Preprocess input batch by moving data to device.\n update_metrics: Update running metrics with model predictions and batch targets.\n finalize_metrics: Finalize metrics including confusion matrix and processing speed.\n postprocess: Extract the primary prediction from model output.\n get_stats: Calculate and return a dictionary of metrics.\n build_dataset: Create a ClassificationDataset instance for validation.\n get_dataloader: Build and return a data loader for classification validation.\n print_results: Print evaluation metrics for the classification model.\n plot_val_samples: Plot validation image samples with their ground truth labels.\n plot_predictions: Plot images with their predicted class labels.\n\nExamples:\n >>> from ultralytics.models.yolo.classify import ClassificationValidator\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\")\n >>> validator = ClassificationValidator(args=args)\n >>> validator()\n\nNotes:\n Torchvision classification models can also be passed to the 'model' argument, i.e. 
model='resnet18'.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "torch", "ultralytics.data.ClassificationDataset", "ultralytics.data.build_dataloader", "ultralytics.engine.validator.BaseValidator", "ultralytics.utils.LOGGER", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.ConfusionMatrix", "ultralytics.utils.plotting.plot_images", "BaseValidator" ], "chunk_id": "class_ClassificationValidator_e697968f" }, { "content": "from ultralytics.models.yolo.classify.predict import ClassificationPredictor", "chunk_type": "import", "name": "ClassificationPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 76, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationPredictor_fc0b1be3" }, { "content": "from ultralytics.models.yolo.classify.train import ClassificationTrainer", "chunk_type": "import", "name": "ClassificationTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationTrainer_b0059f59" }, { "content": "from ultralytics.models.yolo.classify.val import ClassificationValidator", "chunk_type": "import", "name": "ClassificationValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassificationValidator_463f349f" }, { "content": "__all__ = \"ClassificationPredictor\", \"ClassificationTrainer\", \"ClassificationValidator\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 87, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___ab6e9af8" }, { "content": "from ultralytics.engine.predictor import BasePredictor", "chunk_type": "import", "name": "BasePredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BasePredictor_da3146f8" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_360d0d86" }, { "content": "from ultralytics.utils import ops", "chunk_type": "import", "name": "ops", "file_path": 
"ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ops_eb03b6c3" }, { "content": "class DetectionPredictor(BasePredictor):\n \"\"\"\n A class extending the BasePredictor class for prediction based on a detection model.\n\n This predictor specializes in object detection tasks, processing model outputs into meaningful detection results\n with bounding boxes and class predictions.\n\n Attributes:\n args (namespace): Configuration arguments for the predictor.\n model (nn.Module): The detection model used for inference.\n batch (list): Batch of images and metadata for processing.\n\n Methods:\n postprocess: Process raw model predictions into detection results.\n construct_results: Build Results objects from processed predictions.\n construct_result: Create a single Result object from a prediction.\n get_obj_feats: Extract object features from the feature maps.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.detect import DetectionPredictor\n >>> args = dict(model=\"yolo11n.pt\", source=ASSETS)\n >>> predictor = DetectionPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def postprocess(self, preds, img, orig_imgs, **kwargs):\n \"\"\"\n Post-process predictions and return a list of Results objects.\n\n This method applies non-maximum suppression to raw model predictions and prepares them for visualization and\n further analysis.\n\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n img (torch.Tensor): Processed input image tensor in model input format.\n orig_imgs (torch.Tensor | list): Original input images before preprocessing.\n **kwargs (Any): Additional keyword arguments.\n\n Returns:\n (list): List of Results objects containing the post-processed predictions.\n\n Examples:\n >>> predictor = DetectionPredictor(overrides=dict(model=\"yolo11n.pt\"))\n >>> results = predictor.predict(\"path/to/image.jpg\")\n >>> processed_results = predictor.postprocess(preds, img, orig_imgs)\n \"\"\"\n save_feats = getattr(self, \"_feats\", None) is not None\n preds = ops.non_max_suppression(\n preds,\n self.args.conf,\n self.args.iou,\n self.args.classes,\n self.args.agnostic_nms,\n max_det=self.args.max_det,\n nc=0 if self.args.task == \"detect\" else len(self.model.names),\n end2end=getattr(self.model, \"end2end\", False),\n rotated=self.args.task == \"obb\",\n return_idxs=save_feats,\n )\n\n if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n if save_feats:\n obj_feats = self.get_obj_feats(self._feats, preds[1])\n preds = preds[0]\n\n results = self.construct_results(preds, img, orig_imgs, **kwargs)\n\n if save_feats:\n for r, f in zip(results, obj_feats):\n r.feats = f # add object features to results\n\n return results\n\n def get_obj_feats(self, feat_maps, idxs):\n \"\"\"Extract object features from the feature maps.\"\"\"\n import torch\n\n s = min([x.shape[1] for x in feat_maps]) # find smallest vector length\n obj_feats = torch.cat(\n [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1\n ) # mean reduce all vectors to same length\n return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch\n\n def 
construct_results(self, preds, img, orig_imgs):\n \"\"\"\n Construct a list of Results objects from model predictions.\n\n Args:\n preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.\n img (torch.Tensor): Batch of preprocessed images used for inference.\n orig_imgs (List[np.ndarray]): List of original images before preprocessing.\n\n Returns:\n (List[Results]): List of Results objects containing detection information for each image.\n \"\"\"\n return [\n self.construct_result(pred, img, orig_img, img_path)\n for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])\n ]\n\n def construct_result(self, pred, img, orig_img, img_path):\n \"\"\"\n Construct a single Results object from one image prediction.\n\n Args:\n pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.\n img (torch.Tensor): Preprocessed image tensor used for inference.\n orig_img (np.ndarray): Original image before preprocessing.\n img_path (str): Path to the original image file.\n\n Returns:\n (Results): Results object containing the original image, image path, class names, and scaled bounding boxes.\n \"\"\"\n pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)\n return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])", "chunk_type": "class", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py", "start_line": 8, "end_line": 125, "start_col": 0, "end_col": 90, "parent_name": null, "docstring": "A class extending the BasePredictor class for prediction based on a detection model.\n\nThis predictor specializes in object detection tasks, processing model outputs into meaningful detection results\nwith bounding boxes and class predictions.\n\nAttributes:\n args (namespace): Configuration arguments for the predictor.\n model (nn.Module): The detection model used for inference.\n batch (list): Batch of images and metadata for processing.\n\nMethods:\n postprocess: Process raw model predictions into detection results.\n construct_results: Build Results objects from processed predictions.\n construct_result: Create a single Result object from a prediction.\n get_obj_feats: Extract object features from the feature maps.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.detect import DetectionPredictor\n >>> args = dict(model=\"yolo11n.pt\", source=ASSETS)\n >>> predictor = DetectionPredictor(overrides=args)\n >>> predictor.predict_cli()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ultralytics.engine.predictor.BasePredictor", "ultralytics.engine.results.Results", "ultralytics.utils.ops", "torch", "BasePredictor" ], "chunk_id": "class_DetectionPredictor_3c9d40be" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_f0299af0" }, { "content": "import random", "chunk_type": "import", "name": "random", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, 
"decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_random_91b66083" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_a93f4f48" }, { "content": "from typing import Any, Dict, List, Optional", "chunk_type": "import", "name": "Any, Dict, List, Optional", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional_cea6eb92" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_d326ad88" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_0a5f5c0d" }, { "content": "from ultralytics.data import build_dataloader, build_yolo_dataset", "chunk_type": "import", "name": "build_dataloader, build_yolo_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_build_dataloader, build_yolo_dataset_4e1469bb" }, { "content": "from ultralytics.engine.trainer import BaseTrainer", "chunk_type": "import", "name": "BaseTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseTrainer_4b4cf5ee" }, { "content": "from ultralytics.models import yolo", "chunk_type": "import", "name": "yolo", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yolo_7f0d97b6" }, { "content": "from ultralytics.nn.tasks import DetectionModel", "chunk_type": "import", "name": "DetectionModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": 
null, "chunk_id": "import_DetectionModel_f9862d77" }, { "content": "from ultralytics.utils import LOGGER, RANK", "chunk_type": "import", "name": "LOGGER, RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, RANK_5098cbb2" }, { "content": "from ultralytics.utils.patches import override_configs", "chunk_type": "import", "name": "override_configs", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_override_configs_db25ed44" }, { "content": "from ultralytics.utils.plotting import plot_images, plot_labels, plot_results", "chunk_type": "import", "name": "plot_images, plot_labels, plot_results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_plot_images, plot_labels, plot_results_7ca697a2" }, { "content": "from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first", "chunk_type": "import", "name": "de_parallel, torch_distributed_zero_first", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_de_parallel, torch_distributed_zero_first_585c0ffc" }, { "content": "class DetectionTrainer(BaseTrainer):\n \"\"\"\n A class extending the BaseTrainer class for training based on a detection model.\n\n This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models\n for object detection including dataset building, data loading, preprocessing, and model configuration.\n\n Attributes:\n model (DetectionModel): The YOLO detection model being trained.\n data (Dict): Dictionary containing dataset information including class names and number of classes.\n loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).\n\n Methods:\n build_dataset: Build YOLO dataset for training or validation.\n get_dataloader: Construct and return dataloader for the specified mode.\n preprocess_batch: Preprocess a batch of images by scaling and converting to float.\n set_model_attributes: Set model attributes based on dataset information.\n get_model: Return a YOLO detection model.\n get_validator: Return a validator for model evaluation.\n label_loss_items: Return a loss dictionary with labeled training loss items.\n progress_string: Return a formatted string of training progress.\n plot_training_samples: Plot training samples with their annotations.\n plot_metrics: Plot metrics from a CSV file.\n plot_training_labels: Create a labeled training plot of the YOLO model.\n auto_batch: Calculate optimal batch size based on model memory requirements.\n\n Examples:\n >>> from ultralytics.models.yolo.detect 
import DetectionTrainer\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = DetectionTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for 'rect' mode.\n\n Returns:\n (Dataset): YOLO dataset object configured for the specified mode.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == \"val\", stride=gs)\n\n def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = \"train\"):\n \"\"\"\n Construct and return dataloader for the specified mode.\n\n Args:\n dataset_path (str): Path to the dataset.\n batch_size (int): Number of images per batch.\n rank (int): Process rank for distributed training.\n mode (str): 'train' for training dataloader, 'val' for validation dataloader.\n\n Returns:\n (DataLoader): PyTorch dataloader object.\n \"\"\"\n assert mode in {\"train\", \"val\"}, f\"Mode must be 'train' or 'val', not {mode}.\"\n with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP\n dataset = self.build_dataset(dataset_path, mode, batch_size)\n shuffle = mode == \"train\"\n if getattr(dataset, \"rect\", False) and shuffle:\n LOGGER.warning(\"'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False\")\n shuffle = False\n workers = self.args.workers if mode == \"train\" else self.args.workers * 2\n return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader\n\n def preprocess_batch(self, batch: Dict) -> Dict:\n \"\"\"\n Preprocess a batch of images by scaling and converting to float.\n\n Args:\n batch (Dict): Dictionary containing batch data with 'img' tensor.\n\n Returns:\n (Dict): Preprocessed batch with normalized images.\n \"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device, non_blocking=True).float() / 255\n if self.args.multi_scale:\n imgs = batch[\"img\"]\n sz = (\n random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))\n // self.stride\n * self.stride\n ) # size\n sf = sz / max(imgs.shape[2:]) # scale factor\n if sf != 1:\n ns = [\n math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]\n ] # new shape (stretched to gs-multiple)\n imgs = nn.functional.interpolate(imgs, size=ns, mode=\"bilinear\", align_corners=False)\n batch[\"img\"] = imgs\n return batch\n\n def set_model_attributes(self):\n \"\"\"Set model attributes based on dataset information.\"\"\"\n # Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)\n # self.args.box *= 3 / nl # scale to layers\n # self.args.cls *= self.data[\"nc\"] / 80 * 3 / nl # scale to classes and layers\n # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers\n self.model.nc = self.data[\"nc\"] # attach number of classes to model\n self.model.names = self.data[\"names\"] # attach class names to model\n self.model.args = self.args # attach hyperparameters to model\n # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc\n\n def get_model(self, cfg: 
Optional[str] = None, weights: Optional[str] = None, verbose: bool = True):\n \"\"\"\n Return a YOLO detection model.\n\n Args:\n cfg (str, optional): Path to model configuration file.\n weights (str, optional): Path to model weights.\n verbose (bool): Whether to display model information.\n\n Returns:\n (DetectionModel): YOLO detection model.\n \"\"\"\n model = DetectionModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n return model\n\n def get_validator(self):\n \"\"\"Return a DetectionValidator for YOLO model validation.\"\"\"\n self.loss_names = \"box_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.detect.DetectionValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def label_loss_items(self, loss_items: Optional[List[float]] = None, prefix: str = \"train\"):\n \"\"\"\n Return a loss dict with labeled training loss items tensor.\n\n Args:\n loss_items (List[float], optional): List of loss values.\n prefix (str): Prefix for keys in the returned dictionary.\n\n Returns:\n (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.\n \"\"\"\n keys = [f\"{prefix}/{x}\" for x in self.loss_names]\n if loss_items is not None:\n loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats\n return dict(zip(keys, loss_items))\n else:\n return keys\n\n def progress_string(self):\n \"\"\"Return a formatted string of training progress with epoch, GPU memory, loss, instances and size.\"\"\"\n return (\"\\n\" + \"%11s\" * (4 + len(self.loss_names))) % (\n \"Epoch\",\n \"GPU_mem\",\n *self.loss_names,\n \"Instances\",\n \"Size\",\n )\n\n def plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None:\n \"\"\"\n Plot training samples with their annotations.\n\n Args:\n batch (Dict[str, Any]): Dictionary containing batch data.\n ni (int): Number of iterations.\n \"\"\"\n plot_images(\n labels=batch,\n paths=batch[\"im_file\"],\n fname=self.save_dir / f\"train_batch{ni}.jpg\",\n on_plot=self.on_plot,\n )\n\n def plot_metrics(self):\n \"\"\"Plot metrics from a CSV file.\"\"\"\n plot_results(file=self.csv, on_plot=self.on_plot) # save results.png\n\n def plot_training_labels(self):\n \"\"\"Create a labeled training plot of the YOLO model.\"\"\"\n boxes = np.concatenate([lb[\"bboxes\"] for lb in self.train_loader.dataset.labels], 0)\n cls = np.concatenate([lb[\"cls\"] for lb in self.train_loader.dataset.labels], 0)\n plot_labels(boxes, cls.squeeze(), names=self.data[\"names\"], save_dir=self.save_dir, on_plot=self.on_plot)\n\n def auto_batch(self):\n \"\"\"\n Get optimal batch size by calculating memory occupation of model.\n\n Returns:\n (int): Optimal batch size.\n \"\"\"\n with override_configs(self.args, overrides={\"cache\": False}) as self.args:\n train_dataset = self.build_dataset(self.data[\"train\"], mode=\"train\", batch=16)\n max_num_obj = max(len(label[\"cls\"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation\n del train_dataset # free memory\n return super().auto_batch(max_num_obj)", "chunk_type": "class", "name": "DetectionTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py", "start_line": 21, "end_line": 218, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "A class extending the BaseTrainer class for training based on a detection model.\n\nThis trainer specializes in object detection tasks, handling the 
specific requirements for training YOLO models\nfor object detection including dataset building, data loading, preprocessing, and model configuration.\n\nAttributes:\n model (DetectionModel): The YOLO detection model being trained.\n data (Dict): Dictionary containing dataset information including class names and number of classes.\n loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).\n\nMethods:\n build_dataset: Build YOLO dataset for training or validation.\n get_dataloader: Construct and return dataloader for the specified mode.\n preprocess_batch: Preprocess a batch of images by scaling and converting to float.\n set_model_attributes: Set model attributes based on dataset information.\n get_model: Return a YOLO detection model.\n get_validator: Return a validator for model evaluation.\n label_loss_items: Return a loss dictionary with labeled training loss items.\n progress_string: Return a formatted string of training progress.\n plot_training_samples: Plot training samples with their annotations.\n plot_metrics: Plot metrics from a CSV file.\n plot_training_labels: Create a labeled training plot of the YOLO model.\n auto_batch: Calculate optimal batch size based on model memory requirements.\n\nExamples:\n >>> from ultralytics.models.yolo.detect import DetectionTrainer\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = DetectionTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "random", "copy.copy", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "numpy", "torch.nn", "ultralytics.data.build_dataloader", "ultralytics.data.build_yolo_dataset", "ultralytics.engine.trainer.BaseTrainer", "ultralytics.models.yolo", "ultralytics.nn.tasks.DetectionModel", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.patches.override_configs", "ultralytics.utils.plotting.plot_images", "ultralytics.utils.plotting.plot_labels", "ultralytics.utils.plotting.plot_results", "ultralytics.utils.torch_utils.de_parallel", "ultralytics.utils.torch_utils.torch_distributed_zero_first", "BaseTrainer" ], "chunk_id": "class_DetectionTrainer_63774454" }, { "content": "import os", "chunk_type": "import", "name": "os", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_os_37352805" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_41a2c22a" }, { "content": "from typing import Any, Dict, List, Optional, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional, 
Tuple, Union_f7770b17" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_4c28ded0" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_8baa87eb" }, { "content": "from ultralytics.data import build_dataloader, build_yolo_dataset, converter", "chunk_type": "import", "name": "build_dataloader, build_yolo_dataset, converter", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 76, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_build_dataloader, build_yolo_dataset, converter_105d54f6" }, { "content": "from ultralytics.engine.validator import BaseValidator", "chunk_type": "import", "name": "BaseValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_BaseValidator_10db63cf" }, { "content": "from ultralytics.utils import LOGGER, ops", "chunk_type": "import", "name": "LOGGER, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, ops_7a47d0e3" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_ba0bad5b" }, { "content": "from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou", "chunk_type": "import", "name": "ConfusionMatrix, DetMetrics, box_iou", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 74, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ConfusionMatrix, DetMetrics, box_iou_e52987ac" }, { "content": "from ultralytics.utils.plotting import plot_images", "chunk_type": "import", "name": "plot_images", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": 
null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_plot_images_e7d8b65c" }, { "content": "class DetectionValidator(BaseValidator):\n \"\"\"\n A class extending the BaseValidator class for validation based on a detection model.\n\n This class implements validation functionality specific to object detection tasks, including metrics calculation,\n prediction processing, and visualization of results.\n\n Attributes:\n is_coco (bool): Whether the dataset is COCO.\n is_lvis (bool): Whether the dataset is LVIS.\n class_map (List[int]): Mapping from model class indices to dataset class indices.\n metrics (DetMetrics): Object detection metrics calculator.\n iouv (torch.Tensor): IoU thresholds for mAP calculation.\n niou (int): Number of IoU thresholds.\n lb (List[Any]): List for storing ground truth labels for hybrid saving.\n jdict (List[Dict[str, Any]]): List for storing JSON detection results.\n stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.detect import DetectionValidator\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\")\n >>> validator = DetectionValidator(args=args)\n >>> validator()\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize detection validator with necessary variables and settings.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.\n save_dir (Path, optional): Directory to save results.\n args (Dict[str, Any], optional): Arguments for the validator.\n _callbacks (List[Any], optional): List of callback functions.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.is_coco = False\n self.is_lvis = False\n self.class_map = None\n self.args.task = \"detect\"\n self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95\n self.niou = self.iouv.numel()\n self.metrics = DetMetrics()\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Preprocess batch of images for YOLO validation.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Preprocessed batch.\n \"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device, non_blocking=True)\n batch[\"img\"] = (batch[\"img\"].half() if self.args.half else batch[\"img\"].float()) / 255\n for k in {\"batch_idx\", \"cls\", \"bboxes\"}:\n batch[k] = batch[k].to(self.device)\n\n return batch\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize evaluation metrics for YOLO detection validation.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n val = self.data.get(self.args.split, \"\") # validation path\n self.is_coco = (\n isinstance(val, str)\n and \"coco\" in val\n and (val.endswith(f\"{os.sep}val2017.txt\") or val.endswith(f\"{os.sep}test-dev2017.txt\"))\n ) # is COCO\n self.is_lvis = isinstance(val, str) and \"lvis\" in val and not self.is_coco # is LVIS\n self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))\n self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val\n self.names = model.names\n self.nc = len(model.names)\n self.end2end = getattr(model, \"end2end\", False)\n self.seen = 0\n self.jdict = []\n self.metrics.names = self.names\n self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))\n\n def 
get_desc(self) -> str:\n \"\"\"Return a formatted string summarizing class metrics of YOLO model.\"\"\"\n return (\"%22s\" + \"%11s\" * 6) % (\"Class\", \"Images\", \"Instances\", \"Box(P\", \"R\", \"mAP50\", \"mAP50-95)\")\n\n def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Apply Non-maximum suppression to prediction outputs.\n\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n\n Returns:\n (List[Dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains\n 'bboxes', 'conf', 'cls', and 'extra' tensors.\n \"\"\"\n outputs = ops.non_max_suppression(\n preds,\n self.args.conf,\n self.args.iou,\n nc=0 if self.args.task == \"detect\" else self.nc,\n multi_label=True,\n agnostic=self.args.single_cls or self.args.agnostic_nms,\n max_det=self.args.max_det,\n end2end=self.end2end,\n rotated=self.args.task == \"obb\",\n )\n return [{\"bboxes\": x[:, :4], \"conf\": x[:, 4], \"cls\": x[:, 5], \"extra\": x[:, 6:]} for x in outputs]\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch of images and annotations for validation.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Batch data containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Prepared batch with processed annotations.\n \"\"\"\n idx = batch[\"batch_idx\"] == si\n cls = batch[\"cls\"][idx].squeeze(-1)\n bbox = batch[\"bboxes\"][idx]\n ori_shape = batch[\"ori_shape\"][si]\n imgsz = batch[\"img\"].shape[2:]\n ratio_pad = batch[\"ratio_pad\"][si]\n if len(cls):\n bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes\n ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels\n return {\"cls\": cls, \"bboxes\": bbox, \"ori_shape\": ori_shape, \"imgsz\": imgsz, \"ratio_pad\": ratio_pad}\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions for evaluation against ground truth.\n\n Args:\n pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.\n pbatch (Dict[str, Any]): Prepared batch information.\n\n Returns:\n (Dict[str, torch.Tensor]): Prepared predictions in native space.\n \"\"\"\n cls = pred[\"cls\"]\n if self.args.single_cls:\n cls *= 0\n # predn = pred.clone()\n bboxes = ops.scale_boxes(\n pbatch[\"imgsz\"], pred[\"bboxes\"].clone(), pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"]\n ) # native-space pred\n return {\"bboxes\": bboxes, \"conf\": pred[\"conf\"], \"cls\": cls}\n\n def update_metrics(self, preds: List[Dict[str, torch.Tensor]], batch: Dict[str, Any]) -> None:\n \"\"\"\n Update metrics with new predictions and ground truth.\n\n Args:\n preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.\n batch (Dict[str, Any]): Batch data containing ground truth.\n \"\"\"\n for si, pred in enumerate(preds):\n self.seen += 1\n pbatch = self._prepare_batch(si, batch)\n predn = self._prepare_pred(pred, pbatch)\n\n cls = pbatch[\"cls\"].cpu().numpy()\n no_pred = len(predn[\"cls\"]) == 0\n self.metrics.update_stats(\n {\n **self._process_batch(predn, pbatch),\n \"target_cls\": cls,\n \"target_img\": np.unique(cls),\n \"conf\": np.zeros(0) if no_pred else predn[\"conf\"].cpu().numpy(),\n \"pred_cls\": np.zeros(0) if no_pred else predn[\"cls\"].cpu().numpy(),\n }\n )\n # Evaluate\n if self.args.plots:\n self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)\n\n if no_pred:\n 
continue\n\n # Save\n if self.args.save_json:\n self.pred_to_json(predn, batch[\"im_file\"][si])\n if self.args.save_txt:\n self.save_one_txt(\n predn,\n self.args.save_conf,\n pbatch[\"ori_shape\"],\n self.save_dir / \"labels\" / f\"{Path(batch['im_file'][si]).stem}.txt\",\n )\n\n def finalize_metrics(self) -> None:\n \"\"\"Set final values for metrics speed and confusion matrix.\"\"\"\n if self.args.plots:\n for normalize in True, False:\n self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)\n self.metrics.speed = self.speed\n self.metrics.confusion_matrix = self.confusion_matrix\n self.metrics.save_dir = self.save_dir\n\n def get_stats(self) -> Dict[str, Any]:\n \"\"\"\n Calculate and return metrics statistics.\n\n Returns:\n (Dict[str, Any]): Dictionary containing metrics results.\n \"\"\"\n self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)\n self.metrics.clear_stats()\n return self.metrics.results_dict\n\n def print_results(self) -> None:\n \"\"\"Print training/validation set metrics per class.\"\"\"\n pf = \"%22s\" + \"%11i\" * 2 + \"%11.3g\" * len(self.metrics.keys) # print format\n LOGGER.info(pf % (\"all\", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))\n if self.metrics.nt_per_class.sum() == 0:\n LOGGER.warning(f\"no labels found in {self.args.task} set, can not compute metrics without labels\")\n\n # Print results per class\n if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):\n for i, c in enumerate(self.metrics.ap_class_index):\n LOGGER.info(\n pf\n % (\n self.names[c],\n self.metrics.nt_per_image[c],\n self.metrics.nt_per_class[c],\n *self.metrics.class_result(i),\n )\n )\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:\n \"\"\"\n Return correct prediction matrix.\n\n Args:\n preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.\n batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.\n \"\"\"\n if len(batch[\"cls\"]) == 0 or len(preds[\"cls\"]) == 0:\n return {\"tp\": np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)}\n iou = box_iou(batch[\"bboxes\"], preds[\"bboxes\"])\n return {\"tp\": self.match_predictions(preds[\"cls\"], batch[\"cls\"], iou).cpu().numpy()}\n\n def build_dataset(self, img_path: str, mode: str = \"val\", batch: Optional[int] = None) -> torch.utils.data.Dataset:\n \"\"\"\n Build YOLO Dataset.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for `rect`.\n\n Returns:\n (Dataset): YOLO dataset.\n \"\"\"\n return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)\n\n def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:\n \"\"\"\n Construct and return dataloader.\n\n Args:\n dataset_path (str): Path to the dataset.\n batch_size (int): Size of each batch.\n\n Returns:\n (torch.utils.data.DataLoader): Dataloader for validation.\n \"\"\"\n dataset = self.build_dataset(dataset_path, batch=batch_size, mode=\"val\")\n return build_dataloader(dataset, batch_size, self.args.workers, 
shuffle=False, rank=-1) # return dataloader\n\n def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:\n \"\"\"\n Plot validation image samples.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n ni (int): Batch index.\n \"\"\"\n plot_images(\n labels=batch,\n paths=batch[\"im_file\"],\n fname=self.save_dir / f\"val_batch{ni}_labels.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n )\n\n def plot_predictions(\n self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None\n ) -> None:\n \"\"\"\n Plot predicted bounding boxes on input images and save the result.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.\n ni (int): Batch index.\n max_det (Optional[int]): Maximum number of detections to plot.\n \"\"\"\n # TODO: optimize this\n for i, pred in enumerate(preds):\n pred[\"batch_idx\"] = torch.ones_like(pred[\"conf\"]) * i # add batch index to predictions\n keys = preds[0].keys()\n max_det = max_det or self.args.max_det\n batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}\n # TODO: fix this\n batched_preds[\"bboxes\"][:, :4] = ops.xyxy2xywh(batched_preds[\"bboxes\"][:, :4]) # convert to xywh format\n plot_images(\n images=batch[\"img\"],\n labels=batched_preds,\n paths=batch[\"im_file\"],\n fname=self.save_dir / f\"val_batch{ni}_pred.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n ) # pred\n\n def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO detections to a txt file in normalized coordinates in a specific format.\n\n Args:\n predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.\n save_conf (bool): Whether to save confidence scores.\n shape (Tuple[int, int]): Shape of the original image (height, width).\n file (Path): File path to save the detections.\n \"\"\"\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n boxes=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n ).save_txt(file, save_conf=save_conf)\n\n def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:\n \"\"\"\n Serialize YOLO predictions to COCO json format.\n\n Args:\n predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys\n with bounding box coordinates, confidence scores, and class predictions.\n filename (str): Image filename.\n \"\"\"\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n box = ops.xyxy2xywh(predn[\"bboxes\"]) # xywh\n box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner\n for b, s, c in zip(box.tolist(), predn[\"conf\"].tolist(), predn[\"cls\"].tolist()):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"bbox\": [round(x, 3) for x in b],\n \"score\": round(s, 5),\n }\n )\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Evaluate YOLO output in JSON format and return performance statistics.\n\n Args:\n stats (Dict[str, Any]): Current statistics dictionary.\n\n Returns:\n (Dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.\n \"\"\"\n pred_json = self.save_dir / \"predictions.json\" # 
predictions\n anno_json = (\n self.data[\"path\"]\n / \"annotations\"\n / (\"instances_val2017.json\" if self.is_coco else f\"lvis_v1_{self.args.split}.json\")\n ) # annotations\n return self.coco_evaluate(stats, pred_json, anno_json)\n\n def coco_evaluate(\n self,\n stats: Dict[str, Any],\n pred_json: str,\n anno_json: str,\n iou_types: Union[str, List[str]] = \"bbox\",\n suffix: Union[str, List[str]] = \"Box\",\n ) -> Dict[str, Any]:\n \"\"\"\n Evaluate COCO/LVIS metrics using the faster-coco-eval library.\n\n Performs evaluation using the faster-coco-eval library to compute mAP metrics\n for object detection. Updates the provided stats dictionary with computed metrics\n including mAP50, mAP50-95, and LVIS-specific metrics if applicable.\n\n Args:\n stats (Dict[str, Any]): Dictionary to store computed metrics and statistics.\n pred_json (str | Path): Path to JSON file containing predictions in COCO format.\n anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.\n iou_types (str | List[str]): IoU type(s) for evaluation. Can be single string or list of strings.\n Common values include \"bbox\", \"segm\", \"keypoints\". Defaults to \"bbox\".\n suffix (str | List[str]): Suffix to append to metric names in stats dictionary. Should correspond\n to iou_types if multiple types are provided. Defaults to \"Box\".\n\n Returns:\n (Dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.\n \"\"\"\n if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):\n LOGGER.info(f\"\\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...\")\n try:\n for x in pred_json, anno_json:\n assert x.is_file(), f\"{x} file not found\"\n iou_types = [iou_types] if isinstance(iou_types, str) else iou_types\n suffix = [suffix] if isinstance(suffix, str) else suffix\n check_requirements(\"faster-coco-eval>=1.6.7\")\n from faster_coco_eval import COCO, COCOeval_faster\n\n anno = COCO(anno_json)\n pred = anno.loadRes(pred_json)\n for i, iou_type in enumerate(iou_types):\n val = COCOeval_faster(\n anno, pred, iouType=iou_type, lvis_style=self.is_lvis, print_function=LOGGER.info\n )\n val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval\n val.evaluate()\n val.accumulate()\n val.summarize()\n\n # update mAP50-95 and mAP50\n stats[f\"metrics/mAP50({suffix[i][0]})\"] = val.stats_as_dict[\"AP_50\"]\n stats[f\"metrics/mAP50-95({suffix[i][0]})\"] = val.stats_as_dict[\"AP_all\"]\n\n if self.is_lvis:\n stats[f\"metrics/APr({suffix[i][0]})\"] = val.stats_as_dict[\"APr\"]\n stats[f\"metrics/APc({suffix[i][0]})\"] = val.stats_as_dict[\"APc\"]\n stats[f\"metrics/APf({suffix[i][0]})\"] = val.stats_as_dict[\"APf\"]\n\n if self.is_lvis:\n stats[\"fitness\"] = stats[\"metrics/mAP50-95(B)\"] # always use box mAP50-95 for fitness\n except Exception as e:\n LOGGER.warning(f\"faster-coco-eval unable to run: {e}\")\n return stats", "chunk_type": "class", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py", "start_line": 18, "end_line": 465, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "A class extending the BaseValidator class for validation based on a detection model.\n\nThis class implements validation functionality specific to object detection tasks, including metrics calculation,\nprediction processing, and visualization of results.\n\nAttributes:\n is_coco (bool): Whether the dataset is COCO.\n is_lvis (bool): Whether the 
dataset is LVIS.\n class_map (List[int]): Mapping from model class indices to dataset class indices.\n metrics (DetMetrics): Object detection metrics calculator.\n iouv (torch.Tensor): IoU thresholds for mAP calculation.\n niou (int): Number of IoU thresholds.\n lb (List[Any]): List for storing ground truth labels for hybrid saving.\n jdict (List[Dict[str, Any]]): List for storing JSON detection results.\n stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.\n\nExamples:\n >>> from ultralytics.models.yolo.detect import DetectionValidator\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\")\n >>> validator = DetectionValidator(args=args)\n >>> validator()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "os", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.data.build_dataloader", "ultralytics.data.build_yolo_dataset", "ultralytics.data.converter", "ultralytics.engine.validator.BaseValidator", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.metrics.ConfusionMatrix", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.box_iou", "ultralytics.utils.plotting.plot_images", "ultralytics.engine.results.Results", "faster_coco_eval.COCO", "faster_coco_eval.COCOeval_faster", "BaseValidator" ], "chunk_id": "class_DetectionValidator_602f977a" }, { "content": "from .predict import DetectionPredictor", "chunk_type": "import", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionPredictor_5b7f284f" }, { "content": "from .train import DetectionTrainer", "chunk_type": "import", "name": "DetectionTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionTrainer_1e8eb3ba" }, { "content": "from .val import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_a53db3c2" }, { "content": "__all__ = \"DetectionPredictor\", \"DetectionTrainer\", \"DetectionValidator\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___f412584c" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 12, 
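A short usage sketch following the DetectionValidator docstring example above; the model weights and coco8.yaml dataset are assumed to resolve locally (Ultralytics otherwise attempts an auto-download).

```python
# Minimal validation run via the public detect API (mirrors the docstring example).
from ultralytics.models.yolo.detect import DetectionValidator

args = dict(model="yolo11n.pt", data="coco8.yaml")  # assumed to resolve locally or via auto-download
validator = DetectionValidator(args=args)
validator()  # runs validation and logs detection metrics
```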
"parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_aeec3423" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_2191c58a" }, { "content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor", "chunk_type": "import", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionPredictor_ced97923" }, { "content": "from ultralytics.utils import DEFAULT_CFG, ops", "chunk_type": "import", "name": "DEFAULT_CFG, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, ops_6bc7b389" }, { "content": "class OBBPredictor(DetectionPredictor):\n \"\"\"\n A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.\n\n This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated\n bounding boxes.\n\n Attributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO OBB model.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.obb import OBBPredictor\n >>> args = dict(model=\"yolo11n-obb.pt\", source=ASSETS)\n >>> predictor = OBBPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize OBBPredictor with optional model and data configuration overrides.\n\n Args:\n cfg (dict, optional): Default configuration for the predictor.\n overrides (dict, optional): Configuration overrides that take precedence over the default config.\n _callbacks (list, optional): List of callback functions to be invoked during prediction.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.obb import OBBPredictor\n >>> args = dict(model=\"yolo11n-obb.pt\", source=ASSETS)\n >>> predictor = OBBPredictor(overrides=args)\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"obb\"\n\n def construct_result(self, pred, img, orig_img, img_path):\n \"\"\"\n Construct the result object from the prediction.\n\n Args:\n pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where\n the last dimension contains [x, y, w, h, confidence, class_id, angle].\n img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).\n orig_img (np.ndarray): The original image before preprocessing.\n img_path (str): The path to the original image.\n\n Returns:\n (Results): The result object containing the original 
image, image path, class names, and oriented bounding\n boxes.\n \"\"\"\n rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))\n rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)\n obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)\n return Results(orig_img, path=img_path, names=self.model.names, obb=obb)", "chunk_type": "class", "name": "OBBPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py", "start_line": 10, "end_line": 65, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": "A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.\n\nThis predictor handles oriented bounding box detection tasks, processing images and returning results with rotated\nbounding boxes.\n\nAttributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO OBB model.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.obb import OBBPredictor\n >>> args = dict(model=\"yolo11n-obb.pt\", source=ASSETS)\n >>> predictor = OBBPredictor(overrides=args)\n >>> predictor.predict_cli()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "ultralytics.engine.results.Results", "ultralytics.models.yolo.detect.predict.DetectionPredictor", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.ops", "DetectionPredictor" ], "chunk_id": "class_OBBPredictor_0d773c5c" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_570af0db" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_8caf5746" }, { "content": "from typing import Any, List, Optional, Union", "chunk_type": "import", "name": "Any, List, Optional, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Optional, Union_5cedf4ff" }, { "content": "from ultralytics.models import yolo", "chunk_type": "import", "name": "yolo", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yolo_7cdc48ea" }, { "content": "from ultralytics.nn.tasks import OBBModel", "chunk_type": "import", "name": "OBBModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, 
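A plain-torch sketch of the tensor rearrangement performed in construct_result above; ops.regularize_rboxes and ops.scale_boxes are omitted here, and the values are purely illustrative.

```python
# How an OBB prediction row (x, y, w, h, conf, cls, angle) is rearranged into the
# (x, y, w, h, angle, conf, cls) layout passed to Results(obb=...).
import torch

pred = torch.tensor([[320.0, 240.0, 100.0, 40.0, 0.91, 2.0, 0.35]])  # (N, 7) dummy prediction

rboxes = torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)  # (N, 5): x, y, w, h, angle
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)          # (N, 7): x, y, w, h, angle, conf, cls
print(obb.shape, obb[0].tolist())
```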
"decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OBBModel_cab655dc" }, { "content": "from ultralytics.utils import DEFAULT_CFG, RANK", "chunk_type": "import", "name": "DEFAULT_CFG, RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, RANK_23eb45fd" }, { "content": "class OBBTrainer(yolo.detect.DetectionTrainer):\n \"\"\"\n A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.\n\n This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for\n detecting objects at arbitrary angles rather than just axis-aligned rectangles.\n\n Attributes:\n loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,\n and dfl_loss.\n\n Methods:\n get_model: Return OBBModel initialized with specified config and weights.\n get_validator: Return an instance of OBBValidator for validation of YOLO model.\n\n Examples:\n >>> from ultralytics.models.yolo.obb import OBBTrainer\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\", epochs=3)\n >>> trainer = OBBTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[dict] = None, _callbacks: Optional[List[Any]] = None):\n \"\"\"\n Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.\n\n This trainer extends the DetectionTrainer class to specialize in training models that detect oriented\n bounding boxes. It automatically sets the task to 'obb' in the configuration.\n\n Args:\n cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and\n model configuration.\n overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here\n will take precedence over those in cfg.\n _callbacks (List[Any], optional): List of callback functions to be invoked during training.\n\n Examples:\n >>> from ultralytics.models.yolo.obb import OBBTrainer\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\", epochs=3)\n >>> trainer = OBBTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"obb\"\n super().__init__(cfg, overrides, _callbacks)\n\n def get_model(\n self, cfg: Optional[Union[str, dict]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True\n ) -> OBBModel:\n \"\"\"\n Return OBBModel initialized with specified config and weights.\n\n Args:\n cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary\n containing configuration parameters, or None to use default configuration.\n weights (str | Path, optional): Path to pretrained weights file. 
If None, random initialization is used.\n verbose (bool): Whether to display model information during initialization.\n\n Returns:\n (OBBModel): Initialized OBBModel with the specified configuration and weights.\n\n Examples:\n >>> trainer = OBBTrainer()\n >>> model = trainer.get_model(cfg=\"yolo11n-obb.yaml\", weights=\"yolo11n-obb.pt\")\n \"\"\"\n model = OBBModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"Return an instance of OBBValidator for validation of YOLO model.\"\"\"\n self.loss_names = \"box_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.obb.OBBValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )", "chunk_type": "class", "name": "OBBTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py", "start_line": 12, "end_line": 89, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.\n\nThis trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for\ndetecting objects at arbitrary angles rather than just axis-aligned rectangles.\n\nAttributes:\n loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,\n and dfl_loss.\n\nMethods:\n get_model: Return OBBModel initialized with specified config and weights.\n get_validator: Return an instance of OBBValidator for validation of YOLO model.\n\nExamples:\n >>> from ultralytics.models.yolo.obb import OBBTrainer\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\", epochs=3)\n >>> trainer = OBBTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "pathlib.Path", "typing.Any", "typing.List", "typing.Optional", "typing.Union", "ultralytics.models.yolo", "ultralytics.nn.tasks.OBBModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.RANK", "yolo.detect.DetectionTrainer" ], "chunk_id": "class_OBBTrainer_efd0a4ee" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_a1b3e038" }, { "content": "from typing import Any, Dict, List, Tuple, Union", "chunk_type": "import", "name": "Any, Dict, List, Tuple, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple, Union_ebd9e595" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_e2f852ec" }, { "content": "import torch", "chunk_type": "import", "name": 
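A sketch of the get_model pattern used above: build the task model from a YAML config, then optionally load pretrained weights. The class count is an assumption (a DOTA-style 15-class dataset); the config name comes from the docstring example.

```python
# Build an OBB task model directly, as OBBTrainer.get_model does (weight loading optional).
from ultralytics.nn.tasks import OBBModel

model = OBBModel("yolo11n-obb.yaml", nc=15, ch=3, verbose=False)  # nc=15 assumed (DOTA-style dataset)
weights = None  # e.g. "yolo11n-obb.pt"
if weights:
    model.load(weights)
print(sum(p.numel() for p in model.parameters()), "parameters")
```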
"torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_c4d8299d" }, { "content": "from ultralytics.models.yolo.detect import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_2c9c1948" }, { "content": "from ultralytics.utils import LOGGER, ops", "chunk_type": "import", "name": "LOGGER, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, ops_d88f0de0" }, { "content": "from ultralytics.utils.metrics import OBBMetrics, batch_probiou", "chunk_type": "import", "name": "OBBMetrics, batch_probiou", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OBBMetrics, batch_probiou_2ab63936" }, { "content": "class OBBValidator(DetectionValidator):\n \"\"\"\n A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.\n\n This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and\n satellite imagery where objects can appear at various orientations.\n\n Attributes:\n args (dict): Configuration arguments for the validator.\n metrics (OBBMetrics): Metrics object for evaluating OBB model performance.\n is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.\n\n Methods:\n init_metrics: Initialize evaluation metrics for YOLO.\n _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.\n _prepare_batch: Prepare batch data for OBB validation.\n _prepare_pred: Prepare predictions with scaled and padded bounding boxes.\n plot_predictions: Plot predicted bounding boxes on input images.\n pred_to_json: Serialize YOLO predictions to COCO json format.\n save_one_txt: Save YOLO detections to a txt file in normalized coordinates.\n eval_json: Evaluate YOLO output in JSON format and return performance statistics.\n\n Examples:\n >>> from ultralytics.models.yolo.obb import OBBValidator\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\")\n >>> validator = OBBValidator(args=args)\n >>> validator(model=args[\"model\"])\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.\n\n This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.\n It extends the DetectionValidator class and configures it specifically for the OBB task.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): 
Dataloader to be used for validation.\n save_dir (str | Path, optional): Directory to save results.\n args (dict | SimpleNamespace, optional): Arguments containing validation parameters.\n _callbacks (list, optional): List of callback functions to be called during validation.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.args.task = \"obb\"\n self.metrics = OBBMetrics()\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize evaluation metrics for YOLO obb validation.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n super().init_metrics(model)\n val = self.data.get(self.args.split, \"\") # validation path\n self.is_dota = isinstance(val, str) and \"DOTA\" in val # check if dataset is DOTA format\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:\n \"\"\"\n Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.\n\n Args:\n preds (Dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected\n class labels and bounding boxes.\n batch (Dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth\n class labels and bounding boxes.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy\n array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy\n of predictions compared to the ground truth.\n\n Examples:\n >>> detections = torch.rand(100, 7) # 100 sample detections\n >>> gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes\n >>> gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels\n >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)\n \"\"\"\n if len(batch[\"cls\"]) == 0 or len(preds[\"cls\"]) == 0:\n return {\"tp\": np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)}\n iou = batch_probiou(batch[\"bboxes\"], preds[\"bboxes\"])\n return {\"tp\": self.match_predictions(preds[\"cls\"], batch[\"cls\"], iou).cpu().numpy()}\n\n def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n\n Returns:\n (List[Dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.\n \"\"\"\n preds = super().postprocess(preds)\n for pred in preds:\n pred[\"bboxes\"] = torch.cat([pred[\"bboxes\"], pred.pop(\"extra\")], dim=-1) # concatenate angle\n return preds\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare batch data for OBB validation with proper scaling and formatting.\n\n Args:\n si (int): Batch index to process.\n batch (Dict[str, Any]): Dictionary containing batch data with keys:\n - batch_idx: Tensor of batch indices\n - cls: Tensor of class labels\n - bboxes: Tensor of bounding boxes\n - ori_shape: Original image shapes\n - img: Batch of images\n - ratio_pad: Ratio and padding information\n\n Returns:\n (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.\n \"\"\"\n idx = batch[\"batch_idx\"] == si\n cls = batch[\"cls\"][idx].squeeze(-1)\n bbox = batch[\"bboxes\"][idx]\n ori_shape = batch[\"ori_shape\"][si]\n imgsz = batch[\"img\"].shape[2:]\n ratio_pad = batch[\"ratio_pad\"][si]\n if len(cls):\n bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes\n ops.scale_boxes(imgsz, 
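A standalone sketch of the rotated-IoU step in _process_batch above, using batch_probiou from ultralytics.utils.metrics with dummy (x, y, w, h, angle) tensors; the full validator additionally matches classes and thresholds the matrix at 10 IoU levels.

```python
# Probabilistic IoU between ground-truth and predicted rotated boxes (dummy values).
import torch
from ultralytics.utils.metrics import batch_probiou

gt = torch.tensor([[100.0, 100.0, 80.0, 40.0, 0.20]])    # (M, 5) ground-truth rboxes
pred = torch.tensor([[102.0, 98.0, 78.0, 42.0, 0.25],
                     [300.0, 300.0, 50.0, 50.0, 0.00]])   # (N, 5) predicted rboxes

iou = batch_probiou(gt, pred)  # pairwise IoU-like matrix between the two sets
print(iou.shape, iou > 0.5)    # a single-threshold view of the tp matrix built upstream
```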
bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels\n return {\"cls\": cls, \"bboxes\": bbox, \"ori_shape\": ori_shape, \"imgsz\": imgsz, \"ratio_pad\": ratio_pad}\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions by scaling bounding boxes to original image dimensions.\n\n This method takes prediction tensors containing bounding box coordinates and scales them from the model's\n input dimensions to the original image dimensions using the provided batch information.\n\n Args:\n pred (Dict[str, torch.Tensor]): Prediction dictionary containing bounding box coordinates and other information.\n pbatch (Dict[str, Any]): Dictionary containing batch information with keys:\n - imgsz (tuple): Model input image size.\n - ori_shape (tuple): Original image shape.\n - ratio_pad (tuple): Ratio and padding information for scaling.\n\n Returns:\n (Dict[str, torch.Tensor]): Scaled prediction dictionary with bounding boxes in original image dimensions.\n \"\"\"\n cls = pred[\"cls\"]\n if self.args.single_cls:\n cls *= 0\n bboxes = ops.scale_boxes(\n pbatch[\"imgsz\"], pred[\"bboxes\"].clone(), pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"], xywh=True\n ) # native-space pred\n return {\"bboxes\": bboxes, \"conf\": pred[\"conf\"], \"cls\": cls}\n\n def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:\n \"\"\"\n Plot predicted bounding boxes on input images and save the result.\n\n Args:\n batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.\n preds (List[torch.Tensor]): List of prediction tensors for each image in the batch.\n ni (int): Batch index used for naming the output file.\n\n Examples:\n >>> validator = OBBValidator()\n >>> batch = {\"img\": images, \"im_file\": paths}\n >>> preds = [torch.rand(10, 7)] # Example predictions for one image\n >>> validator.plot_predictions(batch, preds, 0)\n \"\"\"\n for p in preds:\n # TODO: fix this duplicated `xywh2xyxy`\n p[\"bboxes\"][:, :4] = ops.xywh2xyxy(p[\"bboxes\"][:, :4]) # convert to xyxy format for plotting\n super().plot_predictions(batch, preds, ni) # plot bboxes\n\n def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: Union[str, Path]) -> None:\n \"\"\"\n Convert YOLO predictions to COCO JSON format with rotated bounding box information.\n\n Args:\n predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys\n with bounding box coordinates, confidence scores, and class predictions.\n filename (str | Path): Path to the image file for which predictions are being processed.\n\n Notes:\n This method processes rotated bounding box predictions and converts them to both rbox format\n (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them\n to the JSON dictionary.\n \"\"\"\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n rbox = predn[\"bboxes\"]\n poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)\n for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn[\"conf\"].tolist(), predn[\"cls\"].tolist()):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"score\": round(s, 5),\n \"rbox\": [round(x, 3) for x in r],\n \"poly\": [round(x, 3) for x in b],\n }\n )\n\n def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO OBB detections to 
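The pred_to_json step above serializes each rotated box twice: as an (x, y, w, h, angle) rbox and as an 8-value polygon. A short sketch of that conversion using the same ops helper (angle assumed to be in radians).

```python
# rbox (x, y, w, h, angle) -> flattened 4-corner polygon, as written to the "poly" JSON field.
import torch
from ultralytics.utils import ops

rbox = torch.tensor([[320.0, 240.0, 100.0, 40.0, 0.35]])  # (N, 5) xywhr
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)                # (N, 8): x1, y1, ..., x4, y4
print([round(v, 3) for v in poly[0].tolist()])
```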
a text file in normalized coordinates.\n\n Args:\n predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,\n class predictions, and angles in format (x, y, w, h, conf, cls, angle).\n save_conf (bool): Whether to save confidence scores in the text file.\n shape (Tuple[int, int]): Original image shape in format (height, width).\n file (Path): Output file path to save detections.\n\n Examples:\n >>> validator = OBBValidator()\n >>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]]) # One detection: x,y,w,h,conf,cls,angle\n >>> validator.save_one_txt(predn, True, (640, 480), \"detection.txt\")\n \"\"\"\n import numpy as np\n\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n obb=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n ).save_txt(file, save_conf=save_conf)\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Evaluate YOLO output in JSON format and save predictions in DOTA format.\n\n Args:\n stats (Dict[str, Any]): Performance statistics dictionary.\n\n Returns:\n (Dict[str, Any]): Updated performance statistics.\n \"\"\"\n if self.args.save_json and self.is_dota and len(self.jdict):\n import json\n import re\n from collections import defaultdict\n\n pred_json = self.save_dir / \"predictions.json\" # predictions\n pred_txt = self.save_dir / \"predictions_txt\" # predictions\n pred_txt.mkdir(parents=True, exist_ok=True)\n data = json.load(open(pred_json))\n # Save split results\n LOGGER.info(f\"Saving predictions with DOTA format to {pred_txt}...\")\n for d in data:\n image_id = d[\"image_id\"]\n score = d[\"score\"]\n classname = self.names[d[\"category_id\"] - 1].replace(\" \", \"-\")\n p = d[\"poly\"]\n\n with open(f\"{pred_txt / f'Task1_{classname}'}.txt\", \"a\", encoding=\"utf-8\") as f:\n f.writelines(f\"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\\n\")\n # Save merged results, this could result slightly lower map than using official merging script,\n # because of the probiou calculation.\n pred_merged_txt = self.save_dir / \"predictions_merged_txt\" # predictions\n pred_merged_txt.mkdir(parents=True, exist_ok=True)\n merged_results = defaultdict(list)\n LOGGER.info(f\"Saving merged predictions with DOTA format to {pred_merged_txt}...\")\n for d in data:\n image_id = d[\"image_id\"].split(\"__\", 1)[0]\n pattern = re.compile(r\"\\d+___\\d+\")\n x, y = (int(c) for c in re.findall(pattern, d[\"image_id\"])[0].split(\"___\"))\n bbox, score, cls = d[\"rbox\"], d[\"score\"], d[\"category_id\"] - 1\n bbox[0] += x\n bbox[1] += y\n bbox.extend([score, cls])\n merged_results[image_id].append(bbox)\n for image_id, bbox in merged_results.items():\n bbox = torch.tensor(bbox)\n max_wh = torch.max(bbox[:, :2]).item() * 2\n c = bbox[:, 6:7] * max_wh # classes\n scores = bbox[:, 5] # scores\n b = bbox[:, :5].clone()\n b[:, :2] += c\n # 0.3 could get results close to the ones from official merging script, even slightly better.\n i = ops.nms_rotated(b, scores, 0.3)\n bbox = bbox[i]\n\n b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)\n for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():\n classname = self.names[int(x[-1])].replace(\" \", \"-\")\n p = [round(i, 3) for i in x[:-2]] # poly\n score = round(x[-2], 3)\n\n with open(f\"{pred_merged_txt / f'Task1_{classname}'}.txt\", \"a\", encoding=\"utf-8\") as f:\n f.writelines(f\"{image_id} 
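The merged-results branch above recovers each DOTA patch's origin from its image id before shifting boxes back into full-image coordinates. A small sketch of that parsing; the patch name is a hypothetical example of the `name__size__x___y` convention assumed by the regex.

```python
# Recover the original image id and the patch offset from a split-image id (hypothetical name).
import re

patch_id = "P0006__1024__0___824"
image_id = patch_id.split("__", 1)[0]                                     # "P0006"
x, y = (int(c) for c in re.findall(r"\d+___\d+", patch_id)[0].split("___"))
print(image_id, x, y)                                                     # P0006 0 824

rbox = [512.0, 300.0, 80.0, 40.0, 0.1]  # x, y, w, h, angle in patch coordinates
rbox[0] += x                            # shift the centre back into full-image coordinates
rbox[1] += y
```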
{score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\\n\")\n\n return stats", "chunk_type": "class", "name": "OBBValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py", "start_line": 14, "end_line": 303, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.\n\nThis validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and\nsatellite imagery where objects can appear at various orientations.\n\nAttributes:\n args (dict): Configuration arguments for the validator.\n metrics (OBBMetrics): Metrics object for evaluating OBB model performance.\n is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.\n\nMethods:\n init_metrics: Initialize evaluation metrics for YOLO.\n _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.\n _prepare_batch: Prepare batch data for OBB validation.\n _prepare_pred: Prepare predictions with scaled and padded bounding boxes.\n plot_predictions: Plot predicted bounding boxes on input images.\n pred_to_json: Serialize YOLO predictions to COCO json format.\n save_one_txt: Save YOLO detections to a txt file in normalized coordinates.\n eval_json: Evaluate YOLO output in JSON format and return performance statistics.\n\nExamples:\n >>> from ultralytics.models.yolo.obb import OBBValidator\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\")\n >>> validator = OBBValidator(args=args)\n >>> validator(model=args[\"model\"])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "typing.Union", "numpy", "torch", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.batch_probiou", "numpy", "ultralytics.engine.results.Results", "json", "re", "collections.defaultdict", "DetectionValidator" ], "chunk_id": "class_OBBValidator_afeefb6c" }, { "content": "from .predict import OBBPredictor", "chunk_type": "import", "name": "OBBPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OBBPredictor_6b8ebe96" }, { "content": "from .train import OBBTrainer", "chunk_type": "import", "name": "OBBTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OBBTrainer_ac16a03d" }, { "content": "from .val import OBBValidator", "chunk_type": "import", "name": "OBBValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OBBValidator_6794c3f5" }, { "content": "__all__ = \"OBBPredictor\", \"OBBTrainer\", \"OBBValidator\"", 
"chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___512c9b5a" }, { "content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor", "chunk_type": "import", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionPredictor_083c2bc6" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, ops", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\predict.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, ops_d2bc05fe" }, { "content": "class PosePredictor(DetectionPredictor):\n \"\"\"\n A class extending the DetectionPredictor class for prediction based on a pose model.\n\n This class specializes in pose estimation, handling keypoints detection alongside standard object detection\n capabilities inherited from DetectionPredictor.\n\n Attributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.\n\n Methods:\n construct_result: Construct the result object from the prediction, including keypoints.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.pose import PosePredictor\n >>> args = dict(model=\"yolo11n-pose.pt\", source=ASSETS)\n >>> predictor = PosePredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize PosePredictor for pose estimation tasks.\n\n Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific\n warnings for Apple MPS.\n\n Args:\n cfg (Any): Configuration for the predictor.\n overrides (dict, optional): Configuration overrides that take precedence over cfg.\n _callbacks (list, optional): List of callback functions to be invoked during prediction.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.pose import PosePredictor\n >>> args = dict(model=\"yolo11n-pose.pt\", source=ASSETS)\n >>> predictor = PosePredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"pose\"\n if isinstance(self.args.device, str) and self.args.device.lower() == \"mps\":\n LOGGER.warning(\n \"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. 
\"\n \"See https://github.com/ultralytics/ultralytics/issues/4031.\"\n )\n\n def construct_result(self, pred, img, orig_img, img_path):\n \"\"\"\n Construct the result object from the prediction, including keypoints.\n\n Extends the parent class implementation by extracting keypoint data from predictions and adding them to the\n result object.\n\n Args:\n pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is\n the number of detections, K is the number of keypoints, and D is the keypoint dimension.\n img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).\n orig_img (np.ndarray): The original unprocessed image as a numpy array.\n img_path (str): The path to the original image file.\n\n Returns:\n (Results): The result object containing the original image, image path, class names, bounding boxes, and\n keypoints.\n \"\"\"\n result = super().construct_result(pred, img, orig_img, img_path)\n # Extract keypoints from prediction and reshape according to model's keypoint shape\n pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape)\n # Scale keypoints coordinates to match the original image dimensions\n pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)\n result.update(keypoints=pred_kpts)\n return result", "chunk_type": "class", "name": "PosePredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\predict.py", "start_line": 7, "end_line": 80, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": "A class extending the DetectionPredictor class for prediction based on a pose model.\n\nThis class specializes in pose estimation, handling keypoints detection alongside standard object detection\ncapabilities inherited from DetectionPredictor.\n\nAttributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.\n\nMethods:\n construct_result: Construct the result object from the prediction, including keypoints.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.pose import PosePredictor\n >>> args = dict(model=\"yolo11n-pose.pt\", source=ASSETS)\n >>> predictor = PosePredictor(overrides=args)\n >>> predictor.predict_cli()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ultralytics.models.yolo.detect.predict.DetectionPredictor", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "DetectionPredictor" ], "chunk_id": "class_PosePredictor_1d8d9515" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_cf2de10a" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_2aaaa099" }, { "content": "from typing import Any, Dict, Optional, Union", "chunk_type": "import", "name": "Any, Dict, Optional, 
Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Optional, Union_42f903a3" }, { "content": "from ultralytics.models import yolo", "chunk_type": "import", "name": "yolo", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yolo_66810175" }, { "content": "from ultralytics.nn.tasks import PoseModel", "chunk_type": "import", "name": "PoseModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_PoseModel_5bf44d06" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER_8a6b8b43" }, { "content": "from ultralytics.utils.plotting import plot_results", "chunk_type": "import", "name": "plot_results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_plot_results_7414fa96" }, { "content": "class PoseTrainer(yolo.detect.DetectionTrainer):\n \"\"\"\n A class extending the DetectionTrainer class for training YOLO pose estimation models.\n\n This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization\n of pose keypoints alongside bounding boxes.\n\n Attributes:\n args (dict): Configuration arguments for training.\n model (PoseModel): The pose estimation model being trained.\n data (dict): Dataset configuration including keypoint shape information.\n loss_names (tuple): Names of the loss components used in training.\n\n Methods:\n get_model: Retrieve a pose estimation model with specified configuration.\n set_model_attributes: Set keypoints shape attribute on the model.\n get_validator: Create a validator instance for model evaluation.\n plot_training_samples: Visualize training samples with keypoints.\n plot_metrics: Generate and save training/validation metric plots.\n get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseTrainer\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\", epochs=3)\n >>> trainer = PoseTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):\n \"\"\"\n Initialize a PoseTrainer object for training YOLO pose estimation 
models.\n\n This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and\n handling specific configurations needed for keypoint detection models.\n\n Args:\n cfg (dict, optional): Default configuration dictionary containing training parameters.\n overrides (dict, optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (list, optional): List of callback functions to be executed during training.\n\n Notes:\n This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.\n A warning is issued when using Apple MPS device due to known bugs with pose models.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseTrainer\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\", epochs=3)\n >>> trainer = PoseTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"pose\"\n super().__init__(cfg, overrides, _callbacks)\n\n if isinstance(self.args.device, str) and self.args.device.lower() == \"mps\":\n LOGGER.warning(\n \"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. \"\n \"See https://github.com/ultralytics/ultralytics/issues/4031.\"\n )\n\n def get_model(\n self,\n cfg: Optional[Union[str, Path, Dict[str, Any]]] = None,\n weights: Optional[Union[str, Path]] = None,\n verbose: bool = True,\n ) -> PoseModel:\n \"\"\"\n Get pose estimation model with specified configuration and weights.\n\n Args:\n cfg (str | Path | dict, optional): Model configuration file path or dictionary.\n weights (str | Path, optional): Path to the model weights file.\n verbose (bool): Whether to display model information.\n\n Returns:\n (PoseModel): Initialized pose estimation model.\n \"\"\"\n model = PoseModel(\n cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], data_kpt_shape=self.data[\"kpt_shape\"], verbose=verbose\n )\n if weights:\n model.load(weights)\n\n return model\n\n def set_model_attributes(self):\n \"\"\"Set keypoints shape attribute of PoseModel.\"\"\"\n super().set_model_attributes()\n self.model.kpt_shape = self.data[\"kpt_shape\"]\n\n def get_validator(self):\n \"\"\"Return an instance of the PoseValidator class for validation.\"\"\"\n self.loss_names = \"box_loss\", \"pose_loss\", \"kobj_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.pose.PoseValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def plot_metrics(self):\n \"\"\"Plot training/validation metrics.\"\"\"\n plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png\n\n def get_dataset(self) -> Dict[str, Any]:\n \"\"\"\n Retrieve the dataset and ensure it contains the required `kpt_shape` key.\n\n Returns:\n (dict): A dictionary containing the training/validation/test dataset and category names.\n\n Raises:\n KeyError: If the `kpt_shape` key is not present in the dataset.\n \"\"\"\n data = super().get_dataset()\n if \"kpt_shape\" not in data:\n raise KeyError(f\"No `kpt_shape` in the {self.args.data}. 
See https://docs.ultralytics.com/datasets/pose/\")\n return data", "chunk_type": "class", "name": "PoseTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py", "start_line": 13, "end_line": 128, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "A class extending the DetectionTrainer class for training YOLO pose estimation models.\n\nThis trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization\nof pose keypoints alongside bounding boxes.\n\nAttributes:\n args (dict): Configuration arguments for training.\n model (PoseModel): The pose estimation model being trained.\n data (dict): Dataset configuration including keypoint shape information.\n loss_names (tuple): Names of the loss components used in training.\n\nMethods:\n get_model: Retrieve a pose estimation model with specified configuration.\n set_model_attributes: Set keypoints shape attribute on the model.\n get_validator: Create a validator instance for model evaluation.\n plot_training_samples: Visualize training samples with keypoints.\n plot_metrics: Generate and save training/validation metric plots.\n get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.\n\nExamples:\n >>> from ultralytics.models.yolo.pose import PoseTrainer\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\", epochs=3)\n >>> trainer = PoseTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Optional", "typing.Union", "ultralytics.models.yolo", "ultralytics.nn.tasks.PoseModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.plotting.plot_results", "yolo.detect.DetectionTrainer" ], "chunk_id": "class_PoseTrainer_75b00860" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_5106f23b" }, { "content": "from typing import Any, Dict, Tuple", "chunk_type": "import", "name": "Any, Dict, Tuple", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Tuple_2193a97b" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_21b0730b" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_35e58519" }, { "content": "from 
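A sketch of the guard implemented by get_dataset above: pose training requires the dataset config to define kpt_shape. The dictionary below is a stand-in for the parsed dataset YAML.

```python
# Pose datasets must declare kpt_shape (e.g. [17, 3]); otherwise training is refused early.
data = {"nc": 1, "channels": 3, "names": {0: "person"}}  # hypothetical dataset dict without kpt_shape

try:
    if "kpt_shape" not in data:
        raise KeyError("No `kpt_shape` in the dataset. See https://docs.ultralytics.com/datasets/pose/")
except KeyError as e:
    print(e)
```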
ultralytics.models.yolo.detect import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_d1720e2b" }, { "content": "from ultralytics.utils import LOGGER, ops", "chunk_type": "import", "name": "LOGGER, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, ops_5b4a1692" }, { "content": "from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou", "chunk_type": "import", "name": "OKS_SIGMA, PoseMetrics, kpt_iou", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OKS_SIGMA, PoseMetrics, kpt_iou_0354a920" }, { "content": "class PoseValidator(DetectionValidator):\n \"\"\"\n A class extending the DetectionValidator class for validation based on a pose model.\n\n This validator is specifically designed for pose estimation tasks, handling keypoints and implementing\n specialized metrics for pose evaluation.\n\n Attributes:\n sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.\n kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.\n args (dict): Arguments for the validator including task set to \"pose\".\n metrics (PoseMetrics): Metrics object for pose evaluation.\n\n Methods:\n preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.\n get_desc: Return description of evaluation metrics in string format.\n init_metrics: Initialize pose estimation metrics for YOLO model.\n _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original\n dimensions.\n _prepare_pred: Prepare and scale keypoints in predictions for pose processing.\n _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between\n detections and ground truth.\n plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.\n plot_predictions: Plot and save model predictions with bounding boxes and keypoints.\n save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.\n pred_to_json: Convert YOLO predictions to COCO JSON format.\n eval_json: Evaluate object detection model using COCO JSON format.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseValidator\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\")\n >>> validator = PoseValidator(args=args)\n >>> validator()\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize a PoseValidator object for pose estimation validation.\n\n This validator is specifically designed for pose estimation tasks, handling keypoints and implementing\n specialized metrics for pose 
evaluation.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.\n save_dir (Path | str, optional): Directory to save results.\n args (dict, optional): Arguments for the validator including task set to \"pose\".\n _callbacks (list, optional): List of callback functions to be executed during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseValidator\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\")\n >>> validator = PoseValidator(args=args)\n >>> validator()\n\n Notes:\n This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values\n for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS\n due to a known bug with pose models.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.sigma = None\n self.kpt_shape = None\n self.args.task = \"pose\"\n self.metrics = PoseMetrics()\n if isinstance(self.args.device, str) and self.args.device.lower() == \"mps\":\n LOGGER.warning(\n \"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. \"\n \"See https://github.com/ultralytics/ultralytics/issues/4031.\"\n )\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess batch by converting keypoints data to float and moving it to the device.\"\"\"\n batch = super().preprocess(batch)\n batch[\"keypoints\"] = batch[\"keypoints\"].to(self.device).float()\n return batch\n\n def get_desc(self) -> str:\n \"\"\"Return description of evaluation metrics in string format.\"\"\"\n return (\"%22s\" + \"%11s\" * 10) % (\n \"Class\",\n \"Images\",\n \"Instances\",\n \"Box(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n \"Pose(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n )\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize evaluation metrics for YOLO pose validation.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n super().init_metrics(model)\n self.kpt_shape = self.data[\"kpt_shape\"]\n is_pose = self.kpt_shape == [17, 3]\n nkpt = self.kpt_shape[0]\n self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt\n\n def postprocess(self, preds: torch.Tensor) -> Dict[str, torch.Tensor]:\n \"\"\"\n Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.\n\n This method extends the parent class postprocessing by extracting keypoints from the 'extra'\n field of predictions and reshaping them according to the keypoint shape configuration.\n The keypoints are reshaped from a flattened format to the proper dimensional structure\n (typically [N, 17, 3] for COCO pose format).\n\n Args:\n preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing\n bounding boxes, confidence scores, class predictions, and keypoint data.\n\n Returns:\n (Dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:\n - 'bboxes': Bounding box coordinates\n - 'conf': Confidence scores\n - 'cls': Class predictions\n - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)\n\n Note:\n If no keypoints are present in a prediction (empty keypoints), that prediction\n is skipped and continues to the next one. 
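A sketch of the sigma selection in init_metrics above: the COCO-style [17, 3] layout uses the standard OKS sigmas, while any other keypoint count falls back to uniform 1/nkpt values.

```python
# Select OKS sigmas for keypoint evaluation (dummy non-COCO layout shown).
import numpy as np
from ultralytics.utils.metrics import OKS_SIGMA

kpt_shape = [5, 3]  # hypothetical 5-keypoint layout
sigma = OKS_SIGMA if kpt_shape == [17, 3] else np.ones(kpt_shape[0]) / kpt_shape[0]
print(sigma)        # five uniform values of 0.2
```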
The keypoints are extracted from the\n 'extra' field which contains additional task-specific data beyond basic detection.\n \"\"\"\n preds = super().postprocess(preds)\n for pred in preds:\n pred[\"keypoints\"] = pred.pop(\"extra\").view(-1, *self.kpt_shape) # remove extra if exists\n return preds\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.\n\n Returns:\n (Dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.\n\n Notes:\n This method extends the parent class's _prepare_batch method by adding keypoint processing.\n Keypoints are scaled from normalized coordinates to original image dimensions.\n \"\"\"\n pbatch = super()._prepare_batch(si, batch)\n kpts = batch[\"keypoints\"][batch[\"batch_idx\"] == si]\n h, w = pbatch[\"imgsz\"]\n kpts = kpts.clone()\n kpts[..., 0] *= w\n kpts[..., 1] *= h\n kpts = ops.scale_coords(pbatch[\"imgsz\"], kpts, pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"])\n pbatch[\"keypoints\"] = kpts\n return pbatch\n\n def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare and scale keypoints in predictions for pose processing.\n\n This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls\n the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates\n to match the original image dimensions.\n\n Args:\n pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.\n pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:\n - imgsz: Image size used for inference\n - ori_shape: Original image shape\n - ratio_pad: Ratio and padding information for coordinate scaling\n\n Returns:\n (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.\n \"\"\"\n predn = super()._prepare_pred(pred, pbatch)\n predn[\"keypoints\"] = ops.scale_coords(\n pbatch[\"imgsz\"], pred.get(\"keypoints\").clone(), pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"]\n )\n return predn\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:\n \"\"\"\n Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.\n\n Args:\n preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions\n and 'keypoints' for keypoint predictions.\n batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,\n 'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose\n true positives across 10 IoU levels.\n\n Notes:\n `0.53` scale factor used in area computation is referenced from\n https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.\n \"\"\"\n tp = super()._process_batch(preds, batch)\n gt_cls = batch[\"cls\"]\n if len(gt_cls) == 0 or len(preds[\"cls\"]) == 0:\n tp_p = np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)\n else:\n # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384\n area 
= ops.xyxy2xywh(batch[\"bboxes\"])[:, 2:].prod(1) * 0.53\n iou = kpt_iou(batch[\"keypoints\"], preds[\"keypoints\"], sigma=self.sigma, area=area)\n tp_p = self.match_predictions(preds[\"cls\"], gt_cls, iou).cpu().numpy()\n tp.update({\"tp_p\": tp_p}) # update tp with kpts IoU\n return tp\n\n def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO pose detections to a text file in normalized coordinates.\n\n Args:\n predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.\n save_conf (bool): Whether to save confidence scores.\n shape (Tuple[int, int]): Shape of the original image (height, width).\n file (Path): Output file path to save detections.\n\n Notes:\n The output format is: class_id x_center y_center width height confidence keypoints where keypoints are\n normalized (x, y, visibility) values for each point.\n \"\"\"\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n boxes=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n keypoints=predn[\"keypoints\"],\n ).save_txt(file, save_conf=save_conf)\n\n def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:\n \"\"\"\n Convert YOLO predictions to COCO JSON format.\n\n This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format\n to COCO format, and appends the results to the internal JSON dictionary (self.jdict).\n\n Args:\n predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',\n and 'keypoints' tensors.\n filename (str): Path to the image file for which predictions are being processed.\n\n Notes:\n The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),\n converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner\n before saving to the JSON dictionary.\n \"\"\"\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n box = ops.xyxy2xywh(predn[\"bboxes\"]) # xywh\n box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner\n for b, s, c, k in zip(\n box.tolist(),\n predn[\"conf\"].tolist(),\n predn[\"cls\"].tolist(),\n predn[\"keypoints\"].flatten(1, 2).tolist(),\n ):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"bbox\": [round(x, 3) for x in b],\n \"keypoints\": k,\n \"score\": round(s, 5),\n }\n )\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Evaluate object detection model using COCO JSON format.\"\"\"\n anno_json = self.data[\"path\"] / \"annotations/person_keypoints_val2017.json\" # annotations\n pred_json = self.save_dir / \"predictions.json\" # predictions\n return super().coco_evaluate(stats, pred_json, anno_json, [\"bbox\", \"keypoints\"], suffix=[\"Box\", \"Pose\"])", "chunk_type": "class", "name": "PoseValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py", "start_line": 14, "end_line": 293, "start_col": 0, "end_col": 112, "parent_name": null, "docstring": "A class extending the DetectionValidator class for validation based on a pose model.\n\nThis validator is specifically designed for pose estimation tasks, handling keypoints and implementing\nspecialized metrics for pose evaluation.\n\nAttributes:\n sigma (np.ndarray): 
Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.\n kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.\n args (dict): Arguments for the validator including task set to \"pose\".\n metrics (PoseMetrics): Metrics object for pose evaluation.\n\nMethods:\n preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.\n get_desc: Return description of evaluation metrics in string format.\n init_metrics: Initialize pose estimation metrics for YOLO model.\n _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original\n dimensions.\n _prepare_pred: Prepare and scale keypoints in predictions for pose processing.\n _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between\n detections and ground truth.\n plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.\n plot_predictions: Plot and save model predictions with bounding boxes and keypoints.\n save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.\n pred_to_json: Convert YOLO predictions to COCO JSON format.\n eval_json: Evaluate object detection model using COCO JSON format.\n\nExamples:\n >>> from ultralytics.models.yolo.pose import PoseValidator\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\")\n >>> validator = PoseValidator(args=args)\n >>> validator()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "typing.Any", "typing.Dict", "typing.Tuple", "numpy", "torch", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.utils.LOGGER", "ultralytics.utils.ops", "ultralytics.utils.metrics.OKS_SIGMA", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.kpt_iou", "ultralytics.engine.results.Results", "DetectionValidator" ], "chunk_id": "class_PoseValidator_b7ce6f5d" }, { "content": "from .predict import PosePredictor", "chunk_type": "import", "name": "PosePredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_PosePredictor_c2d3074f" }, { "content": "from .train import PoseTrainer", "chunk_type": "import", "name": "PoseTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_PoseTrainer_84be1717" }, { "content": "from .val import PoseValidator", "chunk_type": "import", "name": "PoseValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_PoseValidator_f05dff60" }, { "content": "__all__ = \"PoseTrainer\", \"PoseValidator\", \"PosePredictor\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, 
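The PoseValidator chunks above score keypoints with `kpt_iou`, using OKS with per-keypoint sigmas (the `np.ones(nkpt) / nkpt` fallback for non-COCO layouts) and the box area scaled by `0.53`. Below is a minimal NumPy sketch of the COCO-style OKS formula those pieces rely on; it is an illustration under those assumptions, not the ultralytics `kpt_iou` implementation, and the `gt_kpts` / `pred_kpts` names are hypothetical.

```python
# Illustrative sketch of COCO-style Object Keypoint Similarity (OKS), the quantity
# PoseValidator relies on (via kpt_iou) to build its 'tp_p' matrix. Not the
# ultralytics implementation; array names here are hypothetical.
import numpy as np

def oks(gt_kpts: np.ndarray, pred_kpts: np.ndarray, sigma: np.ndarray, area: float, eps: float = 1e-7) -> float:
    """gt_kpts: (K, 3) ground truth (x, y, visibility); pred_kpts: (K, >=2) predictions;
    sigma: (K,) per-keypoint sigmas; area: object area (the validator uses box w*h * 0.53)."""
    d2 = ((gt_kpts[:, :2] - pred_kpts[:, :2]) ** 2).sum(-1)   # squared pixel distances per keypoint
    e = d2 / ((2 * sigma) ** 2 * (area + eps) * 2)            # normalized error, cocoeval convention
    visible = gt_kpts[:, 2] > 0                               # only labeled keypoints contribute
    if not visible.any():
        return 0.0
    return float(np.exp(-e[visible]).mean())

# Hypothetical usage with the generic sigma fallback and the 0.53 area scaling
sigma = np.ones(17) / 17
gt = np.concatenate([np.random.rand(17, 2) * 100, np.ones((17, 1))], axis=1)
score = oks(gt, gt[:, :2] + 1.0, sigma, area=100 * 100 * 0.53)
```

In the validator, the resulting similarity matrix is what `match_predictions` thresholds across the 10 IoU levels to produce the pose true-positive matrix `tp_p`.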
"end_col": 57, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___a8d8fd20" }, { "content": "from ultralytics.engine.results import Results", "chunk_type": "import", "name": "Results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Results_1f6f7c59" }, { "content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor", "chunk_type": "import", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionPredictor_6d55ce8c" }, { "content": "from ultralytics.utils import DEFAULT_CFG, ops", "chunk_type": "import", "name": "DEFAULT_CFG, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, ops_c9819379" }, { "content": "class SegmentationPredictor(DetectionPredictor):\n \"\"\"\n A class extending the DetectionPredictor class for prediction based on a segmentation model.\n\n This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the\n prediction results.\n\n Attributes:\n args (dict): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO segmentation model.\n batch (list): Current batch of images being processed.\n\n Methods:\n postprocess: Apply non-max suppression and process segmentation detections.\n construct_results: Construct a list of result objects from predictions.\n construct_result: Construct a single result object from a prediction.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.segment import SegmentationPredictor\n >>> args = dict(model=\"yolo11n-seg.pt\", source=ASSETS)\n >>> predictor = SegmentationPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the SegmentationPredictor with configuration, overrides, and callbacks.\n\n This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the\n prediction results.\n\n Args:\n cfg (dict): Configuration for the predictor.\n overrides (dict, optional): Configuration overrides that take precedence over cfg.\n _callbacks (list, optional): List of callback functions to be invoked during prediction.\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"segment\"\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Apply non-max suppression and process segmentation detections for each image in the input batch.\n\n Args:\n preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.\n img (torch.Tensor): Input image tensor in model format, with 
shape (B, C, H, W).\n orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.\n\n Returns:\n (list): List of Results objects containing the segmentation predictions for each image in the batch.\n Each Results object includes both bounding boxes and segmentation masks.\n\n Examples:\n >>> predictor = SegmentationPredictor(overrides=dict(model=\"yolo11n-seg.pt\"))\n >>> results = predictor.postprocess(preds, img, orig_img)\n \"\"\"\n # Extract protos - tuple if PyTorch model or array if exported\n protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]\n return super().postprocess(preds[0], img, orig_imgs, protos=protos)\n\n def construct_results(self, preds, img, orig_imgs, protos):\n \"\"\"\n Construct a list of result objects from the predictions.\n\n Args:\n preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.\n img (torch.Tensor): The image after preprocessing.\n orig_imgs (List[np.ndarray]): List of original images before preprocessing.\n protos (List[torch.Tensor]): List of prototype masks.\n\n Returns:\n (List[Results]): List of result objects containing the original images, image paths, class names,\n bounding boxes, and masks.\n \"\"\"\n return [\n self.construct_result(pred, img, orig_img, img_path, proto)\n for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)\n ]\n\n def construct_result(self, pred, img, orig_img, img_path, proto):\n \"\"\"\n Construct a single result object from the prediction.\n\n Args:\n pred (np.ndarray): The predicted bounding boxes, scores, and masks.\n img (torch.Tensor): The image after preprocessing.\n orig_img (np.ndarray): The original image before preprocessing.\n img_path (str): The path to the original image.\n proto (torch.Tensor): The prototype masks.\n\n Returns:\n (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.\n \"\"\"\n if not len(pred): # save empty boxes\n masks = None\n elif self.args.retina_masks:\n pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)\n masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC\n else:\n masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC\n pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)\n if masks is not None:\n keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks\n pred, masks = pred[keep], masks[keep]\n return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)", "chunk_type": "class", "name": "SegmentationPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py", "start_line": 8, "end_line": 113, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "A class extending the DetectionPredictor class for prediction based on a segmentation model.\n\nThis class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the\nprediction results.\n\nAttributes:\n args (dict): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO segmentation model.\n batch (list): Current batch of images being processed.\n\nMethods:\n postprocess: Apply non-max suppression and process segmentation detections.\n construct_results: Construct a list of result objects from predictions.\n construct_result: Construct a single result object from a prediction.\n\nExamples:\n >>> from ultralytics.utils import 
ASSETS\n >>> from ultralytics.models.yolo.segment import SegmentationPredictor\n >>> args = dict(model=\"yolo11n-seg.pt\", source=ASSETS)\n >>> predictor = SegmentationPredictor(overrides=args)\n >>> predictor.predict_cli()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "ultralytics.engine.results.Results", "ultralytics.models.yolo.detect.predict.DetectionPredictor", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.ops", "DetectionPredictor" ], "chunk_id": "class_SegmentationPredictor_bd7c5e5e" }, { "content": "from copy import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_caaa5ce4" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_dbbdf7a2" }, { "content": "from typing import Dict, Optional, Union", "chunk_type": "import", "name": "Dict, Optional, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Dict, Optional, Union_bbea9a10" }, { "content": "from ultralytics.models import yolo", "chunk_type": "import", "name": "yolo", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_yolo_310640d9" }, { "content": "from ultralytics.nn.tasks import SegmentationModel", "chunk_type": "import", "name": "SegmentationModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationModel_54884489" }, { "content": "from ultralytics.utils import DEFAULT_CFG, RANK", "chunk_type": "import", "name": "DEFAULT_CFG, RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, RANK_62644128" }, { "content": "from ultralytics.utils.plotting import plot_results", "chunk_type": "import", "name": "plot_results", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": 
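SegmentationPredictor.construct_result hands each detection's mask coefficients and the prototype tensor to `ops.process_mask` / `ops.process_mask_native`. The sketch below shows the YOLACT-style assembly those helpers are based on, as an illustration only: a linear combination of prototypes, sigmoid, upsample, then crop to the boxes. Shapes and names are assumptions, not the ultralytics implementation.

```python
# Minimal sketch of prototype-based mask assembly (YOLACT-style), the idea behind
# ops.process_mask used by SegmentationPredictor. Illustrative assumptions only.
import torch
import torch.nn.functional as F

def assemble_masks(proto: torch.Tensor, coeffs: torch.Tensor, boxes_xyxy: torch.Tensor,
                   out_shape: tuple) -> torch.Tensor:
    """proto: (C, mh, mw) prototypes; coeffs: (N, C) per-detection coefficients;
    boxes_xyxy: (N, 4) boxes in out_shape pixel coords; returns (N, H, W) boolean masks."""
    c, mh, mw = proto.shape
    masks = (coeffs @ proto.view(c, -1)).sigmoid().view(-1, mh, mw)          # linear combo of prototypes
    masks = F.interpolate(masks[None], size=out_shape, mode="bilinear", align_corners=False)[0]
    n, h, w = masks.shape
    xs = torch.arange(w, device=masks.device).view(1, 1, w)
    ys = torch.arange(h, device=masks.device).view(1, h, 1)
    x1, y1, x2, y2 = (boxes_xyxy[:, i].view(n, 1, 1) for i in range(4))
    inside = (xs >= x1) & (xs < x2) & (ys >= y1) & (ys < y2)                 # crop to each detection's box
    return (masks * inside) > 0.5

# Hypothetical shapes: 32 prototypes at 160x160, 3 detections, 640x640 output
proto = torch.randn(32, 160, 160)
coeffs = torch.randn(3, 32)
boxes = torch.tensor([[10.0, 10.0, 200.0, 300.0], [50.0, 80.0, 400.0, 420.0], [0.0, 0.0, 640.0, 640.0]])
masks = assemble_masks(proto, coeffs, boxes, (640, 640))                     # (3, 640, 640) bool
```

With `retina_masks` the predictor instead scales the boxes to the original image first and runs the native variant at full resolution, but the prototype-coefficient principle is the same.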
null, "chunk_id": "import_plot_results_70bc1128" }, { "content": "class SegmentationTrainer(yolo.detect.DetectionTrainer):\n \"\"\"\n A class extending the DetectionTrainer class for training based on a segmentation model.\n\n This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific\n functionality including model initialization, validation, and visualization.\n\n Attributes:\n loss_names (Tuple[str]): Names of the loss components used during training.\n\n Examples:\n >>> from ultralytics.models.yolo.segment import SegmentationTrainer\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\", epochs=3)\n >>> trainer = SegmentationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):\n \"\"\"\n Initialize a SegmentationTrainer object.\n\n This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific\n functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.\n\n Args:\n cfg (dict): Configuration dictionary with default training settings.\n overrides (dict, optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (list, optional): List of callback functions to be executed during training.\n\n Examples:\n >>> from ultralytics.models.yolo.segment import SegmentationTrainer\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\", epochs=3)\n >>> trainer = SegmentationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"segment\"\n super().__init__(cfg, overrides, _callbacks)\n\n def get_model(\n self, cfg: Optional[Union[Dict, str]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True\n ):\n \"\"\"\n Initialize and return a SegmentationModel with specified configuration and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration. 
Can be a dictionary, a path to a YAML file, or None.\n weights (str | Path, optional): Path to pretrained weights file.\n verbose (bool): Whether to display model information during initialization.\n\n Returns:\n (SegmentationModel): Initialized segmentation model with loaded weights if specified.\n\n Examples:\n >>> trainer = SegmentationTrainer()\n >>> model = trainer.get_model(cfg=\"yolo11n-seg.yaml\")\n >>> model = trainer.get_model(weights=\"yolo11n-seg.pt\", verbose=False)\n \"\"\"\n model = SegmentationModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"Return an instance of SegmentationValidator for validation of YOLO model.\"\"\"\n self.loss_names = \"box_loss\", \"seg_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.segment.SegmentationValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def plot_metrics(self):\n \"\"\"Plot training/validation metrics.\"\"\"\n plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png", "chunk_type": "class", "name": "SegmentationTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py", "start_line": 13, "end_line": 87, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": "A class extending the DetectionTrainer class for training based on a segmentation model.\n\nThis trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific\nfunctionality including model initialization, validation, and visualization.\n\nAttributes:\n loss_names (Tuple[str]): Names of the loss components used during training.\n\nExamples:\n >>> from ultralytics.models.yolo.segment import SegmentationTrainer\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\", epochs=3)\n >>> trainer = SegmentationTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "pathlib.Path", "typing.Dict", "typing.Optional", "typing.Union", "ultralytics.models.yolo", "ultralytics.nn.tasks.SegmentationModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.RANK", "ultralytics.utils.plotting.plot_results", "yolo.detect.DetectionTrainer" ], "chunk_id": "class_SegmentationTrainer_cbc8899e" }, { "content": "from multiprocessing.pool import ThreadPool", "chunk_type": "import", "name": "ThreadPool", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ThreadPool_06f926f7" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_d8db4424" }, { "content": "from typing import Any, Dict, List, Tuple", "chunk_type": "import", "name": "Any, Dict, List, Tuple", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 41, "parent_name": null, 
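SegmentationTrainer is rarely constructed by hand; the same overrides usually reach it through the high-level model API, which can also be handed a customized trainer class, as the WorldTrainerFromScratch example further down does. A small sketch, assuming the standard `ultralytics.YOLO` entry point and the bundled `coco8-seg.yaml` demo dataset:

```python
# Sketch: driving SegmentationTrainer via the high-level API and passing a
# customized subclass explicitly. Assumes the standard ultralytics.YOLO entry
# point and the bundled coco8-seg demo dataset.
from ultralytics import YOLO
from ultralytics.models.yolo.segment import SegmentationTrainer

class QuietSegTrainer(SegmentationTrainer):
    """Hypothetical subclass: skip the metrics plot, inherit everything else."""
    def plot_metrics(self):
        pass  # drop results.png; loss_names, validator, etc. come from the parent

model = YOLO("yolo11n-seg.yaml")  # build from config; task resolves to 'segment'
model.train(data="coco8-seg.yaml", epochs=1, imgsz=640, trainer=QuietSegTrainer)
```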
"docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Tuple_48b48602" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_f1554a43" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_b3575745" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_660eeb7e" }, { "content": "from ultralytics.models.yolo.detect import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_eddf39dd" }, { "content": "from ultralytics.utils import LOGGER, NUM_THREADS, ops", "chunk_type": "import", "name": "LOGGER, NUM_THREADS, ops", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, NUM_THREADS, ops_5c13f41b" }, { "content": "from ultralytics.utils.checks import check_requirements", "chunk_type": "import", "name": "check_requirements", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_requirements_40a3b414" }, { "content": "from ultralytics.utils.metrics import SegmentMetrics, mask_iou", "chunk_type": "import", "name": "SegmentMetrics, mask_iou", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 62, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentMetrics, mask_iou_743469aa" }, { "content": "class SegmentationValidator(DetectionValidator):\n \"\"\"\n A class extending the DetectionValidator class for validation based on a segmentation model.\n\n This validator handles the evaluation of segmentation models, processing 
both bounding box and mask predictions\n to compute metrics such as mAP for both detection and segmentation tasks.\n\n Attributes:\n plot_masks (list): List to store masks for plotting.\n process (callable): Function to process masks based on save_json and save_txt flags.\n args (namespace): Arguments for the validator.\n metrics (SegmentMetrics): Metrics calculator for segmentation tasks.\n stats (dict): Dictionary to store statistics during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.segment import SegmentationValidator\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\")\n >>> validator = SegmentationValidator(args=args)\n >>> validator()\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.\n save_dir (Path, optional): Directory to save results.\n args (namespace, optional): Arguments for the validator.\n _callbacks (list, optional): List of callback functions.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.process = None\n self.args.task = \"segment\"\n self.metrics = SegmentMetrics()\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Preprocess batch of images for YOLO segmentation validation.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Preprocessed batch.\n \"\"\"\n batch = super().preprocess(batch)\n batch[\"masks\"] = batch[\"masks\"].to(self.device).float()\n return batch\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize metrics and select mask processing function based on save_json flag.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n super().init_metrics(model)\n if self.args.save_json:\n check_requirements(\"faster-coco-eval>=1.6.7\")\n # More accurate vs faster\n self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask\n\n def get_desc(self) -> str:\n \"\"\"Return a formatted description of evaluation metrics.\"\"\"\n return (\"%22s\" + \"%11s\" * 10) % (\n \"Class\",\n \"Images\",\n \"Instances\",\n \"Box(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n \"Mask(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n )\n\n def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Post-process YOLO predictions and return output detections with proto.\n\n Args:\n preds (List[torch.Tensor]): Raw predictions from the model.\n\n Returns:\n List[Dict[str, torch.Tensor]]: Processed detection predictions with masks.\n \"\"\"\n proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported\n preds = super().postprocess(preds[0])\n imgsz = [4 * x for x in proto.shape[2:]] # get image size from proto\n for i, pred in enumerate(preds):\n coefficient = pred.pop(\"extra\")\n pred[\"masks\"] = (\n self.process(proto[i], coefficient, pred[\"bboxes\"], shape=imgsz)\n if len(coefficient)\n else torch.zeros(\n (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),\n dtype=torch.uint8,\n device=pred[\"bboxes\"].device,\n )\n )\n return preds\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch for training or inference by processing images and targets.\n\n 
Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Batch data containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Prepared batch with processed annotations.\n \"\"\"\n prepared_batch = super()._prepare_batch(si, batch)\n midx = [si] if self.args.overlap_mask else batch[\"batch_idx\"] == si\n prepared_batch[\"masks\"] = batch[\"masks\"][midx]\n return prepared_batch\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions for evaluation by processing bounding boxes and masks.\n\n Args:\n pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.\n pbatch (Dict[str, Any]): Prepared batch information.\n\n Returns:\n Dict[str, torch.Tensor]: Processed bounding box predictions.\n \"\"\"\n predn = super()._prepare_pred(pred, pbatch)\n predn[\"masks\"] = pred[\"masks\"]\n if self.args.save_json and len(predn[\"masks\"]):\n coco_masks = torch.as_tensor(pred[\"masks\"], dtype=torch.uint8)\n coco_masks = ops.scale_image(\n coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),\n pbatch[\"ori_shape\"],\n ratio_pad=pbatch[\"ratio_pad\"],\n )\n predn[\"coco_masks\"] = coco_masks\n return predn\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:\n \"\"\"\n Compute correct prediction matrix for a batch based on bounding boxes and optional masks.\n\n Args:\n preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.\n batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.\n\n Returns:\n (Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.\n\n Notes:\n - If `masks` is True, the function computes IoU between predicted and ground truth masks.\n - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.\n\n Examples:\n >>> preds = {\"cls\": torch.tensor([1, 0]), \"masks\": torch.rand(2, 640, 640), \"bboxes\": torch.rand(2, 4)}\n >>> batch = {\"cls\": torch.tensor([1, 0]), \"masks\": torch.rand(2, 640, 640), \"bboxes\": torch.rand(2, 4)}\n >>> correct_preds = validator._process_batch(preds, batch)\n \"\"\"\n tp = super()._process_batch(preds, batch)\n gt_cls, gt_masks = batch[\"cls\"], batch[\"masks\"]\n if len(gt_cls) == 0 or len(preds[\"cls\"]) == 0:\n tp_m = np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)\n else:\n pred_masks = preds[\"masks\"]\n if self.args.overlap_mask:\n nl = len(gt_cls)\n index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1\n gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)\n gt_masks = torch.where(gt_masks == index, 1.0, 0.0)\n if gt_masks.shape[1:] != pred_masks.shape[1:]:\n gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode=\"bilinear\", align_corners=False)[0]\n gt_masks = gt_masks.gt_(0.5)\n iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))\n tp_m = self.match_predictions(preds[\"cls\"], gt_cls, iou).cpu().numpy()\n tp.update({\"tp_m\": tp_m}) # update tp with mask IoU\n return tp\n\n def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) -> None:\n \"\"\"\n Plot batch predictions with masks and bounding boxes.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.\n ni (int): Batch 
index.\n \"\"\"\n for p in preds:\n masks = p[\"masks\"]\n if masks.shape[0] > 50:\n LOGGER.warning(\"Limiting validation plots to first 50 items per image for speed...\")\n p[\"masks\"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()\n super().plot_predictions(batch, preds, ni, max_det=50) # plot bboxes\n\n def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO detections to a txt file in normalized coordinates in a specific format.\n\n Args:\n predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).\n save_conf (bool): Whether to save confidence scores.\n shape (Tuple[int, int]): Shape of the original image.\n file (Path): File path to save the detections.\n \"\"\"\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n boxes=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n masks=torch.as_tensor(predn[\"masks\"], dtype=torch.uint8),\n ).save_txt(file, save_conf=save_conf)\n\n def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:\n \"\"\"\n Save one JSON result for COCO evaluation.\n\n Args:\n predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.\n filename (str): Image filename.\n\n Examples:\n >>> result = {\"image_id\": 42, \"category_id\": 18, \"bbox\": [258.15, 41.29, 348.26, 243.78], \"score\": 0.236}\n \"\"\"\n from faster_coco_eval.core.mask import encode # noqa\n\n def single_encode(x):\n \"\"\"Encode predicted masks as RLE and append results to jdict.\"\"\"\n rle = encode(np.asarray(x[:, :, None], order=\"F\", dtype=\"uint8\"))[0]\n rle[\"counts\"] = rle[\"counts\"].decode(\"utf-8\")\n return rle\n\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n box = ops.xyxy2xywh(predn[\"bboxes\"]) # xywh\n box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner\n pred_masks = np.transpose(predn[\"coco_masks\"], (2, 0, 1))\n with ThreadPool(NUM_THREADS) as pool:\n rles = pool.map(single_encode, pred_masks)\n for i, (b, s, c) in enumerate(zip(box.tolist(), predn[\"conf\"].tolist(), predn[\"cls\"].tolist())):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"bbox\": [round(x, 3) for x in b],\n \"score\": round(s, 5),\n \"segmentation\": rles[i],\n }\n )\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Return COCO-style instance segmentation evaluation metrics.\"\"\"\n pred_json = self.save_dir / \"predictions.json\" # predictions\n anno_json = (\n self.data[\"path\"]\n / \"annotations\"\n / (\"instances_val2017.json\" if self.is_coco else f\"lvis_v1_{self.args.split}.json\")\n ) # annotations\n return super().coco_evaluate(stats, pred_json, anno_json, [\"bbox\", \"segm\"], suffix=[\"Box\", \"Mask\"])", "chunk_type": "class", "name": "SegmentationValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py", "start_line": 17, "end_line": 281, "start_col": 0, "end_col": 107, "parent_name": null, "docstring": "A class extending the DetectionValidator class for validation based on a segmentation model.\n\nThis validator handles the evaluation of segmentation models, processing both bounding box and mask predictions\nto compute metrics such as mAP for both detection and segmentation tasks.\n\nAttributes:\n plot_masks (list): List to store masks for plotting.\n 
process (callable): Function to process masks based on save_json and save_txt flags.\n args (namespace): Arguments for the validator.\n metrics (SegmentMetrics): Metrics calculator for segmentation tasks.\n stats (dict): Dictionary to store statistics during validation.\n\nExamples:\n >>> from ultralytics.models.yolo.segment import SegmentationValidator\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\")\n >>> validator = SegmentationValidator(args=args)\n >>> validator()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "multiprocessing.pool.ThreadPool", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Tuple", "numpy", "torch", "torch.nn.functional", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.utils.LOGGER", "ultralytics.utils.NUM_THREADS", "ultralytics.utils.ops", "ultralytics.utils.checks.check_requirements", "ultralytics.utils.metrics.SegmentMetrics", "ultralytics.utils.metrics.mask_iou", "ultralytics.engine.results.Results", "faster_coco_eval.core.mask.encode", "DetectionValidator" ], "chunk_id": "class_SegmentationValidator_6dc7ef2b" }, { "content": "from .predict import SegmentationPredictor", "chunk_type": "import", "name": "SegmentationPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationPredictor_d11aa346" }, { "content": "from .train import SegmentationTrainer", "chunk_type": "import", "name": "SegmentationTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationTrainer_0c904de2" }, { "content": "from .val import SegmentationValidator", "chunk_type": "import", "name": "SegmentationValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationValidator_2396777f" }, { "content": "__all__ = \"SegmentationPredictor\", \"SegmentationTrainer\", \"SegmentationValidator\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___02a783e6" }, { "content": "import itertools", "chunk_type": "import", "name": "itertools", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_itertools_b4d994a3" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": 
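SegmentationValidator._process_batch flattens predicted and ground-truth masks with `.view(n, -1)` and scores them with `mask_iou` before `match_predictions` builds the `tp_m` matrix. A minimal sketch of that pairwise mask IoU on flattened binary masks, for illustration only (tensor names are assumptions, not the ultralytics implementation):

```python
# Sketch of pairwise mask IoU on flattened binary masks, the quantity
# SegmentationValidator feeds into match_predictions. Illustrative only.
import torch

def pairwise_mask_iou(gt_masks: torch.Tensor, pred_masks: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """gt_masks: (N, H*W) and pred_masks: (M, H*W) binary {0,1} tensors -> (N, M) IoU matrix."""
    gt = gt_masks.float()
    pred = pred_masks.float()
    inter = gt @ pred.T                                     # (N, M) overlapping pixel counts
    union = gt.sum(1, keepdim=True) + pred.sum(1) - inter   # (N,1) + (M,) - (N,M) via broadcasting
    return inter / (union + eps)

# Shapes mirror the validator's call: masks are flattened before scoring
gt = (torch.rand(3, 160 * 160) > 0.5).long()
pred = (torch.rand(5, 160 * 160) > 0.5).long()
iou = pairwise_mask_iou(gt, pred)                           # (3, 5)
```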
"ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_6c0c7cc5" }, { "content": "from typing import Any, Dict, List, Optional", "chunk_type": "import", "name": "Any, Dict, List, Optional", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, List, Optional_bdf62c65" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_7ad04b71" }, { "content": "from ultralytics.data import build_yolo_dataset", "chunk_type": "import", "name": "build_yolo_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_build_yolo_dataset_fc6bae89" }, { "content": "from ultralytics.models.yolo.detect import DetectionTrainer", "chunk_type": "import", "name": "DetectionTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionTrainer_91e8b8a8" }, { "content": "from ultralytics.nn.tasks import WorldModel", "chunk_type": "import", "name": "WorldModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_WorldModel_ba3f49ae" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, RANK_52b50460" }, { "content": "from ultralytics.utils.torch_utils import de_parallel", "chunk_type": "import", "name": "de_parallel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_de_parallel_ed6180ba" }, { "content": "def on_pretrain_routine_end(trainer) -> None:\n \"\"\"Set up model classes and text encoder 
at the end of the pretrain routine.\"\"\"\n if RANK in {-1, 0}:\n # Set class names for evaluation\n names = [name.split(\"/\", 1)[0] for name in list(trainer.test_loader.dataset.data[\"names\"].values())]\n de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)", "chunk_type": "function", "name": "on_pretrain_routine_end", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 16, "end_line": 21, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": "Set up model classes and text encoder at the end of the pretrain routine.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "itertools", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "torch", "ultralytics.data.build_yolo_dataset", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel" ], "chunk_id": "function_on_pretrain_routine_end_a2916ba6" }, { "content": "class WorldTrainer(DetectionTrainer):\n \"\"\"\n A trainer class for fine-tuning YOLO World models on close-set datasets.\n\n This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual\n features for improved object detection and understanding. It handles text embedding generation and caching to\n accelerate training with multi-modal data.\n\n Attributes:\n text_embeddings (Dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate\n training.\n model (WorldModel): The YOLO World model being trained.\n data (Dict[str, Any]): Dataset configuration containing class information.\n args (Any): Training arguments and configuration.\n\n Methods:\n get_model: Return WorldModel initialized with specified config and weights.\n build_dataset: Build YOLO Dataset for training or validation.\n set_text_embeddings: Set text embeddings for datasets to accelerate training.\n generate_text_embeddings: Generate text embeddings for a list of text samples.\n preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.\n\n Examples:\n Initialize and train a YOLO World model\n >>> from ultralytics.models.yolo.world import WorldTrainer\n >>> args = dict(model=\"yolov8s-world.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = WorldTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):\n \"\"\"\n Initialize a WorldTrainer object with given arguments.\n\n Args:\n cfg (Dict[str, Any]): Configuration for the trainer.\n overrides (Dict[str, Any], optional): Configuration overrides.\n _callbacks (List[Any], optional): List of callback functions.\n \"\"\"\n if overrides is None:\n overrides = {}\n super().__init__(cfg, overrides, _callbacks)\n self.text_embeddings = None\n\n def get_model(self, cfg=None, weights: Optional[str] = None, verbose: bool = True) -> WorldModel:\n \"\"\"\n Return WorldModel initialized with specified config and weights.\n\n Args:\n cfg (Dict[str, Any] | str, optional): Model configuration.\n weights (str, optional): Path to pretrained weights.\n verbose (bool): Whether to display model info.\n\n Returns:\n (WorldModel): Initialized WorldModel.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: 
Following the official config, nc hard-coded to 80 for now.\n model = WorldModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=min(self.data[\"nc\"], 80),\n verbose=verbose and RANK == -1,\n )\n if weights:\n model.load(weights)\n self.add_callback(\"on_pretrain_routine_end\", on_pretrain_routine_end)\n\n return model\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for `rect`.\n\n Returns:\n (Any): YOLO dataset configured for training or validation.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n dataset = build_yolo_dataset(\n self.args, img_path, batch, self.data, mode=mode, rect=mode == \"val\", stride=gs, multi_modal=mode == \"train\"\n )\n if mode == \"train\":\n self.set_text_embeddings([dataset], batch) # cache text embeddings to accelerate training\n return dataset\n\n def set_text_embeddings(self, datasets: List[Any], batch: Optional[int]) -> None:\n \"\"\"\n Set text embeddings for datasets to accelerate training by caching category names.\n\n This method collects unique category names from all datasets, then generates and caches text embeddings\n for these categories to improve training efficiency.\n\n Args:\n datasets (List[Any]): List of datasets from which to extract category names.\n batch (int | None): Batch size used for processing.\n\n Notes:\n This method collects category names from datasets that have the 'category_names' attribute,\n then uses the first dataset's image path to determine where to cache the generated text embeddings.\n \"\"\"\n text_embeddings = {}\n for dataset in datasets:\n if not hasattr(dataset, \"category_names\"):\n continue\n text_embeddings.update(\n self.generate_text_embeddings(\n list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent\n )\n )\n self.text_embeddings = text_embeddings\n\n def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path) -> Dict[str, torch.Tensor]:\n \"\"\"\n Generate text embeddings for a list of text samples.\n\n Args:\n texts (List[str]): List of text samples to encode.\n batch (int): Batch size for processing.\n cache_dir (Path): Directory to save/load cached embeddings.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.\n \"\"\"\n model = \"clip:ViT-B/32\"\n cache_path = cache_dir / f\"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt\"\n if cache_path.exists():\n LOGGER.info(f\"Reading existed cache from '{cache_path}'\")\n txt_map = torch.load(cache_path, map_location=self.device)\n if sorted(txt_map.keys()) == sorted(texts):\n return txt_map\n LOGGER.info(f\"Caching text embeddings to '{cache_path}'\")\n assert self.model is not None\n txt_feats = de_parallel(self.model).get_text_pe(texts, batch, cache_clip_model=False)\n txt_map = dict(zip(texts, txt_feats.squeeze(0)))\n torch.save(txt_map, cache_path)\n return txt_map\n\n def preprocess_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess a batch of images and text for YOLOWorld training.\"\"\"\n batch = DetectionTrainer.preprocess_batch(self, batch)\n\n # Add text features\n texts = list(itertools.chain(*batch[\"texts\"]))\n 
txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)\n txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)\n batch[\"txt_feats\"] = txt_feats.reshape(len(batch[\"texts\"]), -1, txt_feats.shape[-1])\n return batch", "chunk_type": "class", "name": "WorldTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py", "start_line": 24, "end_line": 175, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "A trainer class for fine-tuning YOLO World models on close-set datasets.\n\nThis trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual\nfeatures for improved object detection and understanding. It handles text embedding generation and caching to\naccelerate training with multi-modal data.\n\nAttributes:\n text_embeddings (Dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate\n training.\n model (WorldModel): The YOLO World model being trained.\n data (Dict[str, Any]): Dataset configuration containing class information.\n args (Any): Training arguments and configuration.\n\nMethods:\n get_model: Return WorldModel initialized with specified config and weights.\n build_dataset: Build YOLO Dataset for training or validation.\n set_text_embeddings: Set text embeddings for datasets to accelerate training.\n generate_text_embeddings: Generate text embeddings for a list of text samples.\n preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.\n\nExamples:\n Initialize and train a YOLO World model\n >>> from ultralytics.models.yolo.world import WorldTrainer\n >>> args = dict(model=\"yolov8s-world.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = WorldTrainer(overrides=args)\n >>> trainer.train()", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "pathlib.Path", "typing.Any", "typing.Dict", "typing.List", "typing.Optional", "torch", "ultralytics.data.build_yolo_dataset", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.nn.tasks.WorldModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel", "DetectionTrainer" ], "chunk_id": "class_WorldTrainer_dced3b02" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_cebd899c" }, { "content": "from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset", "chunk_type": "import", "name": "YOLOConcatDataset, build_grounding, build_yolo_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOConcatDataset, build_grounding, build_yolo_dataset_cf0238a8" }, { "content": "from ultralytics.data.utils import check_det_dataset", "chunk_type": "import", "name": "check_det_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 6, "end_line": 6, 
"start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_det_dataset_581e0267" }, { "content": "from ultralytics.models.yolo.world import WorldTrainer", "chunk_type": "import", "name": "WorldTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 54, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_WorldTrainer_afd3fd51" }, { "content": "from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER", "chunk_type": "import", "name": "DATASETS_DIR, DEFAULT_CFG, LOGGER", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DATASETS_DIR, DEFAULT_CFG, LOGGER_29233daf" }, { "content": "from ultralytics.utils.torch_utils import de_parallel", "chunk_type": "import", "name": "de_parallel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_de_parallel_c46092b2" }, { "content": "class WorldTrainerFromScratch(WorldTrainer):\n \"\"\"\n A class extending the WorldTrainer for training a world model from scratch on open-set datasets.\n\n This trainer specializes in handling mixed datasets including both object detection and grounding datasets,\n supporting training YOLO-World models with combined vision-language capabilities.\n\n Attributes:\n cfg (dict): Configuration dictionary with default parameters for model training.\n overrides (dict): Dictionary of parameter overrides to customize the configuration.\n _callbacks (list): List of callback functions to be executed during different stages of training.\n data (dict): Final processed data configuration containing train/val paths and metadata.\n training_data (dict): Dictionary mapping training dataset paths to their configurations.\n\n Methods:\n build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.\n get_dataset: Get train and validation paths from data dictionary.\n plot_training_labels: Skip label plotting for YOLO-World training.\n final_eval: Perform final evaluation and validation for the YOLO-World model.\n\n Examples:\n >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n >>> from ultralytics import YOLOWorld\n >>> data = dict(\n ... train=dict(\n ... yolo_data=[\"Objects365.yaml\"],\n ... grounding_data=[\n ... dict(\n ... img_path=\"flickr30k/images\",\n ... json_file=\"flickr30k/final_flickr_separateGT_train.json\",\n ... ),\n ... dict(\n ... img_path=\"GQA/images\",\n ... json_file=\"GQA/final_mixed_train_no_coco.json\",\n ... ),\n ... ],\n ... ),\n ... val=dict(yolo_data=[\"lvis.yaml\"]),\n ... 
)\n >>> model = YOLOWorld(\"yolov8s-worldv2.yaml\")\n >>> model.train(data=data, trainer=WorldTrainerFromScratch)\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize a WorldTrainerFromScratch object.\n\n This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both\n object detection and grounding datasets for vision-language capabilities.\n\n Args:\n cfg (dict): Configuration dictionary with default parameters for model training.\n overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.\n _callbacks (list, optional): List of callback functions to be executed during different stages of training.\n\n Examples:\n >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n >>> from ultralytics import YOLOWorld\n >>> data = dict(\n ... train=dict(\n ... yolo_data=[\"Objects365.yaml\"],\n ... grounding_data=[\n ... dict(\n ... img_path=\"flickr30k/images\",\n ... json_file=\"flickr30k/final_flickr_separateGT_train.json\",\n ... ),\n ... ],\n ... ),\n ... val=dict(yolo_data=[\"lvis.yaml\"]),\n ... )\n >>> model = YOLOWorld(\"yolov8s-worldv2.yaml\")\n >>> model.train(data=data, trainer=WorldTrainerFromScratch)\n \"\"\"\n if overrides is None:\n overrides = {}\n super().__init__(cfg, overrides, _callbacks)\n\n def build_dataset(self, img_path, mode=\"train\", batch=None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n This method constructs appropriate datasets based on the mode and input paths, handling both\n standard YOLO datasets and grounding datasets with different formats.\n\n Args:\n img_path (List[str] | str): Path to the folder containing images or list of paths.\n mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.\n batch (int, optional): Size of batches, used for rectangular training/validation.\n\n Returns:\n (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n if mode != \"train\":\n return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)\n datasets = [\n build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)\n if isinstance(im_path, str)\n else build_grounding(\n # assign `nc` from validation set to max number of text samples for training consistency\n self.args,\n im_path[\"img_path\"],\n im_path[\"json_file\"],\n batch,\n stride=gs,\n max_samples=self.data[\"nc\"],\n )\n for im_path in img_path\n ]\n self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training\n return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]\n\n def get_dataset(self):\n \"\"\"\n Get train and validation paths from data dictionary.\n\n Processes the data configuration to extract paths for training and validation datasets,\n handling both YOLO detection datasets and grounding datasets.\n\n Returns:\n train_path (str): Train dataset path.\n val_path (str): Validation dataset path.\n\n Raises:\n AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.\n \"\"\"\n final_data = {}\n data_yaml = self.args.data\n assert data_yaml.get(\"train\", False), \"train dataset not found\" # object365.yaml\n assert data_yaml.get(\"val\", False), \"validation dataset not found\" # lvis.yaml\n data = {k: 
[check_det_dataset(d) for d in v.get(\"yolo_data\", [])] for k, v in data_yaml.items()}\n assert len(data[\"val\"]) == 1, f\"Only support validating on 1 dataset for now, but got {len(data['val'])}.\"\n val_split = \"minival\" if \"lvis\" in data[\"val\"][0][\"val\"] else \"val\"\n for d in data[\"val\"]:\n if d.get(\"minival\") is None: # for lvis dataset\n continue\n d[\"minival\"] = str(d[\"path\"] / d[\"minival\"])\n for s in {\"train\", \"val\"}:\n final_data[s] = [d[\"train\" if s == \"train\" else val_split] for d in data[s]]\n # save grounding data if there's one\n grounding_data = data_yaml[s].get(\"grounding_data\")\n if grounding_data is None:\n continue\n grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]\n for g in grounding_data:\n assert isinstance(g, dict), f\"Grounding data should be provided in dict format, but got {type(g)}\"\n for k in {\"img_path\", \"json_file\"}:\n path = Path(g[k])\n if not path.exists() and not path.is_absolute():\n g[k] = str((DATASETS_DIR / g[k]).resolve()) # path relative to DATASETS_DIR\n final_data[s] += grounding_data\n # assign the first val dataset as currently only one validation set is supported\n data[\"val\"] = data[\"val\"][0]\n final_data[\"val\"] = final_data[\"val\"][0]\n # NOTE: to make training work properly, set `nc` and `names`\n final_data[\"nc\"] = data[\"val\"][\"nc\"]\n final_data[\"names\"] = data[\"val\"][\"names\"]\n # NOTE: add path with lvis path\n final_data[\"path\"] = data[\"val\"][\"path\"]\n final_data[\"channels\"] = data[\"val\"][\"channels\"]\n self.data = final_data\n if self.args.single_cls: # consistent with base trainer\n LOGGER.info(\"Overriding class names with single class.\")\n self.data[\"names\"] = {0: \"object\"}\n self.data[\"nc\"] = 1\n self.training_data = {}\n for d in data[\"train\"]:\n if self.args.single_cls:\n d[\"names\"] = {0: \"object\"}\n d[\"nc\"] = 1\n self.training_data[d[\"train\"]] = d\n return final_data\n\n def plot_training_labels(self):\n \"\"\"Skip label plotting for YOLO-World training.\"\"\"\n pass\n\n def final_eval(self):\n \"\"\"\n Perform final evaluation and validation for the YOLO-World model.\n\n Configures the validator with appropriate dataset and split information before running evaluation.\n\n Returns:\n (dict): Dictionary containing evaluation metrics and results.\n \"\"\"\n val = self.args.data[\"val\"][\"yolo_data\"][0]\n self.validator.args.data = val\n self.validator.args.split = \"minival\" if isinstance(val, str) and \"lvis\" in val else \"val\"\n return super().final_eval()", "chunk_type": "class", "name": "WorldTrainerFromScratch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py", "start_line": 12, "end_line": 201, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "A class extending the WorldTrainer for training a world model from scratch on open-set datasets.\n\nThis trainer specializes in handling mixed datasets including both object detection and grounding datasets,\nsupporting training YOLO-World models with combined vision-language capabilities.\n\nAttributes:\n cfg (dict): Configuration dictionary with default parameters for model training.\n overrides (dict): Dictionary of parameter overrides to customize the configuration.\n _callbacks (list): List of callback functions to be executed during different stages of training.\n data (dict): Final processed data configuration containing train/val paths and metadata.\n training_data (dict): Dictionary mapping training dataset 
paths to their configurations.\n\nMethods:\n build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.\n get_dataset: Get train and validation paths from data dictionary.\n plot_training_labels: Skip label plotting for YOLO-World training.\n final_eval: Perform final evaluation and validation for the YOLO-World model.\n\nExamples:\n >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n >>> from ultralytics import YOLOWorld\n >>> data = dict(\n ... train=dict(\n ... yolo_data=[\"Objects365.yaml\"],\n ... grounding_data=[\n ... dict(\n ... img_path=\"flickr30k/images\",\n ... json_file=\"flickr30k/final_flickr_separateGT_train.json\",\n ... ),\n ... dict(\n ... img_path=\"GQA/images\",\n ... json_file=\"GQA/final_mixed_train_no_coco.json\",\n ... ),\n ... ],\n ... ),\n ... val=dict(yolo_data=[\"lvis.yaml\"]),\n ... )\n >>> model = YOLOWorld(\"yolov8s-worldv2.yaml\")\n >>> model.train(data=data, trainer=WorldTrainerFromScratch)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "pathlib.Path", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_grounding", "ultralytics.data.build_yolo_dataset", "ultralytics.data.utils.check_det_dataset", "ultralytics.models.yolo.world.WorldTrainer", "ultralytics.utils.DATASETS_DIR", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.torch_utils.de_parallel", "WorldTrainer" ], "chunk_id": "class_WorldTrainerFromScratch_777a0ffc" }, { "content": "from .train import WorldTrainer", "chunk_type": "import", "name": "WorldTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_WorldTrainer_2412de85" }, { "content": "__all__ = [\"WorldTrainer\"]", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___aac26038" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_492ba178" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_cae74323" }, { "content": "from ultralytics.data.augment import LoadVisualPrompt", "chunk_type": "import", "name": "LoadVisualPrompt", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
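# Hedged sketch of how WorldTrainerFromScratch.get_dataset normalizes grounding
# entries: each entry must be a dict with 'img_path' and 'json_file', and any
# relative path that does not exist locally is resolved against DATASETS_DIR.
# `DATASETS_DIR` below is a placeholder Path, not the ultralytics setting itself.
from pathlib import Path
from typing import Dict, List

DATASETS_DIR = Path("/datasets")  # assumption: stand-in for ultralytics.utils.DATASETS_DIR


def resolve_grounding_paths(grounding_data: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Resolve relative grounding-data paths the way the trainer does."""
    for g in grounding_data:
        assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
        for k in ("img_path", "json_file"):
            path = Path(g[k])
            if not path.exists() and not path.is_absolute():
                g[k] = str((DATASETS_DIR / g[k]).resolve())  # path relative to DATASETS_DIR
    return grounding_data


# Usage mirroring the docstring example's flickr30k entry
grounding = resolve_grounding_paths(
    [dict(img_path="flickr30k/images", json_file="flickr30k/final_flickr_separateGT_train.json")]
)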
"complexity_score": null, "dependencies": null, "chunk_id": "import_LoadVisualPrompt_636b68f2" }, { "content": "from ultralytics.models.yolo.detect import DetectionPredictor", "chunk_type": "import", "name": "DetectionPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionPredictor_52754e1d" }, { "content": "from ultralytics.models.yolo.segment import SegmentationPredictor", "chunk_type": "import", "name": "SegmentationPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationPredictor_504c32b1" }, { "content": "class YOLOEVPDetectPredictor(DetectionPredictor):\n \"\"\"\n A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.\n\n This mixin provides common functionality for YOLO models that use visual prompting, including\n model setup, prompt handling, and preprocessing transformations.\n\n Attributes:\n model (torch.nn.Module): The YOLO model for inference.\n device (torch.device): Device to run the model on (CPU or CUDA).\n prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.\n\n Methods:\n setup_model: Initialize the YOLO model and set it to evaluation mode.\n set_prompts: Set the visual prompts for the model.\n pre_transform: Preprocess images and prompts before inference.\n inference: Run inference with visual prompts.\n get_vpe: Process source to get visual prompt embeddings.\n \"\"\"\n\n def setup_model(self, model, verbose: bool = True):\n \"\"\"\n Set up the model for prediction.\n\n Args:\n model (torch.nn.Module): Model to load or use.\n verbose (bool, optional): If True, provides detailed logging.\n \"\"\"\n super().setup_model(model, verbose=verbose)\n self.done_warmup = True\n\n def set_prompts(self, prompts):\n \"\"\"\n Set the visual prompts for the model.\n\n Args:\n prompts (dict): Dictionary containing class indices and bounding boxes or masks.\n Must include a 'cls' key with class indices.\n \"\"\"\n self.prompts = prompts\n\n def pre_transform(self, im):\n \"\"\"\n Preprocess images and prompts before inference.\n\n This method applies letterboxing to the input image and transforms the visual prompts\n (bounding boxes or masks) accordingly.\n\n Args:\n im (list): List containing a single input image.\n\n Returns:\n (list): Preprocessed image ready for model inference.\n\n Raises:\n ValueError: If neither valid bounding boxes nor masks are provided in the prompts.\n \"\"\"\n img = super().pre_transform(im)\n bboxes = self.prompts.pop(\"bboxes\", None)\n masks = self.prompts.pop(\"masks\", None)\n category = self.prompts[\"cls\"]\n if len(img) == 1:\n visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)\n self.prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)\n else:\n # NOTE: only supports bboxes as prompts for now\n assert bboxes is not None, f\"Expected bboxes, but got {bboxes}!\"\n # NOTE: needs List[np.ndarray]\n assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (\n f\"Expected 
List[np.ndarray], but got {bboxes}!\"\n )\n assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (\n f\"Expected List[np.ndarray], but got {category}!\"\n )\n assert len(im) == len(category) == len(bboxes), (\n f\"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!\"\n )\n visuals = [\n self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])\n for i in range(len(img))\n ]\n self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)\n\n return img\n\n def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):\n \"\"\"\n Process a single image by resizing bounding boxes or masks and generating visuals.\n\n Args:\n dst_shape (tuple): The target shape (height, width) of the image.\n src_shape (tuple): The original shape (height, width) of the image.\n category (str): The category of the image for visual prompts.\n bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].\n masks (np.ndarray, optional): A list of masks corresponding to the image.\n\n Returns:\n (torch.Tensor): The processed visuals for the image.\n\n Raises:\n ValueError: If neither `bboxes` nor `masks` are provided.\n \"\"\"\n if bboxes is not None and len(bboxes):\n bboxes = np.array(bboxes, dtype=np.float32)\n if bboxes.ndim == 1:\n bboxes = bboxes[None, :]\n # Calculate scaling factor and adjust bounding boxes\n gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) # gain = old / new\n bboxes *= gain\n bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)\n bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)\n elif masks is not None:\n # Resize and process masks\n resized_masks = super().pre_transform(masks)\n masks = np.stack(resized_masks) # (N, H, W)\n masks[masks == 114] = 0 # Reset padding values to 0\n else:\n raise ValueError(\"Please provide valid bboxes or masks\")\n\n # Generate visuals using the visual prompt loader\n return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)\n\n def inference(self, im, *args, **kwargs):\n \"\"\"\n Run inference with visual prompts.\n\n Args:\n im (torch.Tensor): Input image tensor.\n *args (Any): Variable length argument list.\n **kwargs (Any): Arbitrary keyword arguments.\n\n Returns:\n (torch.Tensor): Model prediction results.\n \"\"\"\n return super().inference(im, vpe=self.prompts, *args, **kwargs)\n\n def get_vpe(self, source):\n \"\"\"\n Process the source to get the visual prompt embeddings (VPE).\n\n Args:\n source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source\n of the image to make predictions on. 
Accepts various types including file paths, URLs, PIL\n images, numpy arrays, and torch tensors.\n\n Returns:\n (torch.Tensor): The visual prompt embeddings (VPE) from the model.\n \"\"\"\n self.setup_source(source)\n assert len(self.dataset) == 1, \"get_vpe only supports one image!\"\n for _, im0s, _ in self.dataset:\n im = self.preprocess(im0s)\n return self.model(im, vpe=self.prompts, return_vpe=True)", "chunk_type": "class", "name": "YOLOEVPDetectPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 11, "end_line": 163, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": "A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.\n\nThis mixin provides common functionality for YOLO models that use visual prompting, including\nmodel setup, prompt handling, and preprocessing transformations.\n\nAttributes:\n model (torch.nn.Module): The YOLO model for inference.\n device (torch.device): Device to run the model on (CPU or CUDA).\n prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.\n\nMethods:\n setup_model: Initialize the YOLO model and set it to evaluation mode.\n set_prompts: Set the visual prompts for the model.\n pre_transform: Preprocess images and prompts before inference.\n inference: Run inference with visual prompts.\n get_vpe: Process source to get visual prompt embeddings.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "numpy", "torch", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionPredictor", "ultralytics.models.yolo.segment.SegmentationPredictor", "DetectionPredictor" ], "chunk_id": "class_YOLOEVPDetectPredictor_e62fb0e7" }, { "content": "class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):\n \"\"\"Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities.\"\"\"\n\n pass", "chunk_type": "class", "name": "YOLOEVPSegPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py", "start_line": 166, "end_line": 169, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "numpy", "torch", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionPredictor", "ultralytics.models.yolo.segment.SegmentationPredictor", "YOLOEVPDetectPredictor", "SegmentationPredictor" ], "chunk_id": "class_YOLOEVPSegPredictor_169e6011" }, { "content": "import itertools", "chunk_type": "import", "name": "itertools", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_itertools_83951bc0" }, { "content": "from copy import copy, deepcopy", "chunk_type": "import", "name": "copy, deepcopy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy, deepcopy_a95ce580" }, { "content": "from pathlib 
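# Hedged sketch of the letterbox bbox adjustment performed inside
# YOLOEVPDetectPredictor._process_single_image: prompt boxes are scaled by the
# resize gain and shifted by the symmetric padding so they line up with the
# letterboxed image. The image shapes used below are illustrative values only.
import numpy as np


def letterbox_bboxes(bboxes: np.ndarray, src_shape: tuple, dst_shape: tuple) -> np.ndarray:
    """Map [x1, y1, x2, y2] boxes from the original image into letterboxed coordinates."""
    bboxes = np.asarray(bboxes, dtype=np.float32)
    if bboxes.ndim == 1:
        bboxes = bboxes[None, :]
    gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # resize gain (letterboxed / original)
    bboxes = bboxes * gain
    bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)  # x padding offset
    bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)  # y padding offset
    return bboxes


# A 480x640 image letterboxed to 640x640 keeps scale 1.0 and shifts y by 80 px
print(letterbox_bboxes(np.array([100, 50, 300, 200]), src_shape=(480, 640), dst_shape=(640, 640)))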
import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_a1b0f115" }, { "content": "from typing import Dict, List, Optional, Union", "chunk_type": "import", "name": "Dict, List, Optional, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Dict, List, Optional, Union_a4a23fb8" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_6310badd" }, { "content": "from ultralytics.data import YOLOConcatDataset, build_yolo_dataset", "chunk_type": "import", "name": "YOLOConcatDataset, build_yolo_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 66, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOConcatDataset, build_yolo_dataset_80e4499d" }, { "content": "from ultralytics.data.augment import LoadVisualPrompt", "chunk_type": "import", "name": "LoadVisualPrompt", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LoadVisualPrompt_6cfad980" }, { "content": "from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator", "chunk_type": "import", "name": "DetectionTrainer, DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionTrainer, DetectionValidator_cd51ace3" }, { "content": "from ultralytics.nn.tasks import YOLOEModel", "chunk_type": "import", "name": "YOLOEModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEModel_feeb20fc" }, { "content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK", "chunk_type": "import", "name": "DEFAULT_CFG, LOGGER, RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
"complexity_score": null, "dependencies": null, "chunk_id": "import_DEFAULT_CFG, LOGGER, RANK_f06b862c" }, { "content": "from ultralytics.utils.torch_utils import de_parallel", "chunk_type": "import", "name": "de_parallel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_de_parallel_f537c039" }, { "content": "from ..world.train_world import WorldTrainerFromScratch", "chunk_type": "import", "name": "WorldTrainerFromScratch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_WorldTrainerFromScratch_d87e7e7f" }, { "content": "from .val import YOLOEDetectValidator", "chunk_type": "import", "name": "YOLOEDetectValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEDetectValidator_390eccc6" }, { "content": "class YOLOETrainer(DetectionTrainer):\n \"\"\"\n A trainer class for YOLOE object detection models.\n\n This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,\n including custom model initialization, validation, and dataset building with multi-modal support.\n\n Attributes:\n loss_names (tuple): Names of loss components used during training.\n\n Methods:\n get_model: Initialize and return a YOLOEModel with specified configuration.\n get_validator: Return a YOLOEDetectValidator for model validation.\n build_dataset: Build YOLO dataset with multi-modal support for training.\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):\n \"\"\"\n Initialize the YOLOE Trainer with specified configurations.\n\n This method sets up the YOLOE trainer with the provided configuration and overrides, initializing\n the training environment, model, and callbacks for YOLOE object detection training.\n\n Args:\n cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.\n overrides (dict, optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (list, optional): List of callback functions to be applied during training.\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"overlap_mask\"] = False\n super().__init__(cfg, overrides, _callbacks)\n\n def get_model(self, cfg=None, weights=None, verbose: bool = True):\n \"\"\"\n Return a YOLOEModel initialized with the specified configuration and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration. 
Can be a dictionary containing a 'yaml_file' key,\n a direct path to a YAML file, or None to use default configuration.\n weights (str | Path, optional): Path to pretrained weights file to load into the model.\n verbose (bool): Whether to display model information during initialization.\n\n Returns:\n (YOLOEModel): The initialized YOLOE model.\n\n Notes:\n - The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.\n - The nc parameter here represents the maximum number of different text samples in one image,\n rather than the actual number of classes.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOEModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=min(self.data[\"nc\"], 80),\n verbose=verbose and RANK == -1,\n )\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"Return a YOLOEDetectValidator for YOLOE model validation.\"\"\"\n self.loss_names = \"box\", \"cls\", \"dfl\"\n return YOLOEDetectValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for rectangular training.\n\n Returns:\n (Dataset): YOLO dataset configured for training or validation.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n return build_yolo_dataset(\n self.args, img_path, batch, self.data, mode=mode, rect=mode == \"val\", stride=gs, multi_modal=mode == \"train\"\n )", "chunk_type": "class", "name": "YOLOETrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 21, "end_line": 107, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "A trainer class for YOLOE object detection models.\n\nThis class extends DetectionTrainer to provide specialized training functionality for YOLOE models,\nincluding custom model initialization, validation, and dataset building with multi-modal support.\n\nAttributes:\n loss_names (tuple): Names of loss components used during training.\n\nMethods:\n get_model: Initialize and return a YOLOEModel with specified configuration.\n get_validator: Return a YOLOEDetectValidator for model validation.\n build_dataset: Build YOLO dataset with multi-modal support for training.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "copy.copy", "copy.deepcopy", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "torch", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel", "world.train_world.WorldTrainerFromScratch", "val.YOLOEDetectValidator", "DetectionTrainer" ], "chunk_id": "class_YOLOETrainer_876d5b58" }, { 
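# Hedged sketch of the grid-size logic shared by the build_dataset methods above:
# the dataset stride is the model's maximum detection stride when a model is
# already attached, floored at 32. The de_parallel unwrapping is omitted here,
# and `_Dummy` is an illustrative module with a `stride` tensor attribute.
from typing import Optional

import torch


def dataset_stride(model: Optional[torch.nn.Module]) -> int:
    """Return the stride used when building the YOLO dataset (rect padding granularity)."""
    max_stride = int(model.stride.max()) if model is not None else 0
    return max(max_stride, 32)


class _Dummy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stride = torch.tensor([8.0, 16.0, 64.0])


assert dataset_stride(None) == 32      # no model yet: fall back to 32
assert dataset_stride(_Dummy()) == 64  # otherwise use the largest head stride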
"content": "class YOLOEPETrainer(DetectionTrainer):\n \"\"\"\n Fine-tune YOLOE model using linear probing approach.\n\n This trainer freezes most model layers and only trains specific projection layers for efficient\n fine-tuning on new datasets while preserving pretrained features.\n\n Methods:\n get_model: Initialize YOLOEModel with frozen layers except projection layers.\n \"\"\"\n\n def get_model(self, cfg=None, weights=None, verbose: bool = True):\n \"\"\"\n Return YOLOEModel initialized with specified config and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration.\n weights (str, optional): Path to pretrained weights.\n verbose (bool): Whether to display model information.\n\n Returns:\n (YOLOEModel): Initialized model with frozen layers except for specific projection layers.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOEModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=self.data[\"nc\"],\n verbose=verbose and RANK == -1,\n )\n\n del model.model[-1].savpe\n\n assert weights is not None, \"Pretrained weights must be provided for linear probing.\"\n if weights:\n model.load(weights)\n\n model.eval()\n names = list(self.data[\"names\"].values())\n # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,\n # it'd get correct results as long as loading proper pretrained weights.\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n model.model[-1].fuse(model.pe) # fuse text embeddings to classify head\n model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)\n model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)\n model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)\n del model.pe\n model.train()\n\n return model", "chunk_type": "class", "name": "YOLOEPETrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 110, "end_line": 161, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Fine-tune YOLOE model using linear probing approach.\n\nThis trainer freezes most model layers and only trains specific projection layers for efficient\nfine-tuning on new datasets while preserving pretrained features.\n\nMethods:\n get_model: Initialize YOLOEModel with frozen layers except projection layers.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "copy.copy", "copy.deepcopy", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "torch", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel", "world.train_world.WorldTrainerFromScratch", "val.YOLOEDetectValidator", "DetectionTrainer" ], "chunk_id": "class_YOLOEPETrainer_12abeb65" }, { "content": "class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):\n \"\"\"\n Train YOLOE models from scratch with text embedding support.\n\n This trainer combines YOLOE training capabilities with world training features, enabling\n training 
from scratch with text embeddings and grounding datasets.\n\n Methods:\n build_dataset: Build datasets for training with grounding support.\n preprocess_batch: Process batches with text features.\n generate_text_embeddings: Generate and cache text embeddings for training.\n \"\"\"\n\n def build_dataset(self, img_path: Union[List[str], str], mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n This method constructs appropriate datasets based on the mode and input paths, handling both\n standard YOLO datasets and grounding datasets with different formats.\n\n Args:\n img_path (List[str] | str): Path to the folder containing images or list of paths.\n mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.\n batch (int, optional): Size of batches, used for rectangular training/validation.\n\n Returns:\n (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.\n \"\"\"\n return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)\n\n def preprocess_batch(self, batch):\n \"\"\"Process batch for training, moving text features to the appropriate device.\"\"\"\n batch = DetectionTrainer.preprocess_batch(self, batch)\n\n texts = list(itertools.chain(*batch[\"texts\"]))\n txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)\n txt_feats = txt_feats.reshape(len(batch[\"texts\"]), -1, txt_feats.shape[-1])\n batch[\"txt_feats\"] = txt_feats\n return batch\n\n def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path):\n \"\"\"\n Generate text embeddings for a list of text samples.\n\n Args:\n texts (List[str]): List of text samples to encode.\n batch (int): Batch size for processing.\n cache_dir (Path): Directory to save/load cached embeddings.\n\n Returns:\n (dict): Dictionary mapping text samples to their embeddings.\n \"\"\"\n model = \"mobileclip:blt\"\n cache_path = cache_dir / f\"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt\"\n if cache_path.exists():\n LOGGER.info(f\"Reading existed cache from '{cache_path}'\")\n txt_map = torch.load(cache_path, map_location=self.device)\n if sorted(txt_map.keys()) == sorted(texts):\n return txt_map\n LOGGER.info(f\"Caching text embeddings to '{cache_path}'\")\n assert self.model is not None\n txt_feats = de_parallel(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)\n txt_map = dict(zip(texts, txt_feats.squeeze(0)))\n torch.save(txt_map, cache_path)\n return txt_map", "chunk_type": "class", "name": "YOLOETrainerFromScratch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 164, "end_line": 228, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Train YOLOE models from scratch with text embedding support.\n\nThis trainer combines YOLOE training capabilities with world training features, enabling\ntraining from scratch with text embeddings and grounding datasets.\n\nMethods:\n build_dataset: Build datasets for training with grounding support.\n preprocess_batch: Process batches with text features.\n generate_text_embeddings: Generate and cache text embeddings for training.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "copy.copy", "copy.deepcopy", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "torch", "ultralytics.data.YOLOConcatDataset", 
"ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel", "world.train_world.WorldTrainerFromScratch", "val.YOLOEDetectValidator", "YOLOETrainer", "WorldTrainerFromScratch" ], "chunk_id": "class_YOLOETrainerFromScratch_2a19afb9" }, { "content": "class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):\n \"\"\"\n Train prompt-free YOLOE model.\n\n This trainer combines linear probing capabilities with from-scratch training for prompt-free\n YOLOE models that don't require text prompts during inference.\n\n Methods:\n get_validator: Return standard DetectionValidator for validation.\n preprocess_batch: Preprocess batches without text features.\n set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).\n \"\"\"\n\n def get_validator(self):\n \"\"\"Return a DetectionValidator for YOLO model validation.\"\"\"\n self.loss_names = \"box\", \"cls\", \"dfl\"\n return DetectionValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def preprocess_batch(self, batch):\n \"\"\"Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed.\"\"\"\n batch = DetectionTrainer.preprocess_batch(self, batch)\n return batch\n\n def set_text_embeddings(self, datasets, batch: int):\n \"\"\"\n Set text embeddings for datasets to accelerate training by caching category names.\n\n This method collects unique category names from all datasets, generates text embeddings for them,\n and caches these embeddings to improve training efficiency. The embeddings are stored in a file\n in the parent directory of the first dataset's image path.\n\n Args:\n datasets (List[Dataset]): List of datasets containing category names to process.\n batch (int): Batch size for processing text embeddings.\n\n Notes:\n The method creates a dictionary mapping text samples to their embeddings and stores it\n at the path specified by 'cache_path'. 
If the cache file already exists, it will be loaded\n instead of regenerating the embeddings.\n \"\"\"\n pass", "chunk_type": "class", "name": "YOLOEPEFreeTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 231, "end_line": 273, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Train prompt-free YOLOE model.\n\nThis trainer combines linear probing capabilities with from-scratch training for prompt-free\nYOLOE models that don't require text prompts during inference.\n\nMethods:\n get_validator: Return standard DetectionValidator for validation.\n preprocess_batch: Preprocess batches without text features.\n set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "copy.copy", "copy.deepcopy", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "torch", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel", "world.train_world.WorldTrainerFromScratch", "val.YOLOEDetectValidator", "YOLOEPETrainer", "YOLOETrainerFromScratch" ], "chunk_id": "class_YOLOEPEFreeTrainer_a87d96eb" }, { "content": "class YOLOEVPTrainer(YOLOETrainerFromScratch):\n \"\"\"\n Train YOLOE model with visual prompts.\n\n This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,\n where visual cues are provided alongside images to guide the detection process.\n\n Methods:\n build_dataset: Build dataset with visual prompt loading transforms.\n preprocess_batch: Preprocess batches with visual prompts.\n \"\"\"\n\n def build_dataset(self, img_path: Union[List[str], str], mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation with visual prompts.\n\n Args:\n img_path (List[str] | str): Path to the folder containing images or list of paths.\n mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.\n batch (int, optional): Size of batches, used for rectangular training/validation.\n\n Returns:\n (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.\n \"\"\"\n dataset = super().build_dataset(img_path, mode, batch)\n if isinstance(dataset, YOLOConcatDataset):\n for d in dataset.datasets:\n d.transforms.append(LoadVisualPrompt())\n else:\n dataset.transforms.append(LoadVisualPrompt())\n return dataset\n\n def _close_dataloader_mosaic(self):\n \"\"\"Close mosaic augmentation and add visual prompt loading to the training dataset.\"\"\"\n super()._close_dataloader_mosaic()\n if isinstance(self.train_loader.dataset, YOLOConcatDataset):\n for d in self.train_loader.dataset.datasets:\n d.transforms.append(LoadVisualPrompt())\n else:\n self.train_loader.dataset.transforms.append(LoadVisualPrompt())\n\n def preprocess_batch(self, batch):\n \"\"\"Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device.\"\"\"\n batch = super().preprocess_batch(batch)\n batch[\"visuals\"] = batch[\"visuals\"].to(self.device)\n return batch", "chunk_type": "class", "name": "YOLOEVPTrainer", "file_path": 
"ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py", "start_line": 276, "end_line": 321, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Train YOLOE model with visual prompts.\n\nThis trainer extends YOLOETrainerFromScratch to support visual prompt-based training,\nwhere visual cues are provided alongside images to guide the detection process.\n\nMethods:\n build_dataset: Build dataset with visual prompt loading transforms.\n preprocess_batch: Preprocess batches with visual prompts.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "itertools", "copy.copy", "copy.deepcopy", "pathlib.Path", "typing.Dict", "typing.List", "typing.Optional", "typing.Union", "torch", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.models.yolo.detect.DetectionTrainer", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.DEFAULT_CFG", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.torch_utils.de_parallel", "world.train_world.WorldTrainerFromScratch", "val.YOLOEDetectValidator", "YOLOETrainerFromScratch" ], "chunk_id": "class_YOLOEVPTrainer_21b21ae1" }, { "content": "from copy import copy, deepcopy", "chunk_type": "import", "name": "copy, deepcopy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy, deepcopy_14f8fd7e" }, { "content": "from ultralytics.models.yolo.segment import SegmentationTrainer", "chunk_type": "import", "name": "SegmentationTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 63, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationTrainer_7b8fa4f0" }, { "content": "from ultralytics.nn.tasks import YOLOESegModel", "chunk_type": "import", "name": "YOLOESegModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOESegModel_223f18e7" }, { "content": "from ultralytics.utils import RANK", "chunk_type": "import", "name": "RANK", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_RANK_a43de328" }, { "content": "from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer", "chunk_type": "import", "name": "YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": 
"import_YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer_48db9426" }, { "content": "from .val import YOLOESegValidator", "chunk_type": "import", "name": "YOLOESegValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOESegValidator_73974eef" }, { "content": "class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):\n \"\"\"\n Trainer class for YOLOE segmentation models.\n\n This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE\n segmentation models, enabling both object detection and instance segmentation capabilities.\n\n Attributes:\n cfg (dict): Configuration dictionary with training parameters.\n overrides (dict): Dictionary with parameter overrides.\n _callbacks (list): List of callback functions for training events.\n \"\"\"\n\n def get_model(self, cfg=None, weights=None, verbose=True):\n \"\"\"\n Return YOLOESegModel initialized with specified config and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration dictionary or YAML file path.\n weights (str, optional): Path to pretrained weights file.\n verbose (bool): Whether to display model information.\n\n Returns:\n (YOLOESegModel): Initialized YOLOE segmentation model.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOESegModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=min(self.data[\"nc\"], 80),\n verbose=verbose and RANK == -1,\n )\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"\n Create and return a validator for YOLOE segmentation model evaluation.\n\n Returns:\n (YOLOESegValidator): Validator for YOLOE segmentation models.\n \"\"\"\n self.loss_names = \"box\", \"seg\", \"cls\", \"dfl\"\n return YOLOESegValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )", "chunk_type": "class", "name": "YOLOESegTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 13, "end_line": 61, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Trainer class for YOLOE segmentation models.\n\nThis class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE\nsegmentation models, enabling both object detection and instance segmentation capabilities.\n\nAttributes:\n cfg (dict): Configuration dictionary with training parameters.\n overrides (dict): Dictionary with parameter overrides.\n _callbacks (list): List of callback functions for training events.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "copy.deepcopy", "ultralytics.models.yolo.segment.SegmentationTrainer", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.RANK", "train.YOLOETrainer", "train.YOLOETrainerFromScratch", "train.YOLOEVPTrainer", "val.YOLOESegValidator", "YOLOETrainer", "SegmentationTrainer" ], "chunk_id": "class_YOLOESegTrainer_902592da" }, { "content": "class YOLOEPESegTrainer(SegmentationTrainer):\n \"\"\"\n Fine-tune YOLOESeg model in linear probing way.\n\n 
This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing\n most of the model and only training specific layers for efficient adaptation to new tasks.\n\n Attributes:\n data (dict): Dataset configuration containing channels, class names, and number of classes.\n \"\"\"\n\n def get_model(self, cfg=None, weights=None, verbose=True):\n \"\"\"\n Return YOLOESegModel initialized with specified config and weights for linear probing.\n\n Args:\n cfg (dict | str, optional): Model configuration dictionary or YAML file path.\n weights (str, optional): Path to pretrained weights file.\n verbose (bool): Whether to display model information.\n\n Returns:\n (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOESegModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=self.data[\"nc\"],\n verbose=verbose and RANK == -1,\n )\n\n del model.model[-1].savpe\n\n assert weights is not None, \"Pretrained weights must be provided for linear probing.\"\n if weights:\n model.load(weights)\n\n model.eval()\n names = list(self.data[\"names\"].values())\n # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,\n # it'd get correct results as long as loading proper pretrained weights.\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n model.model[-1].fuse(model.pe)\n model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)\n model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)\n model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)\n del model.pe\n model.train()\n\n return model", "chunk_type": "class", "name": "YOLOEPESegTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 64, "end_line": 115, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "Fine-tune YOLOESeg model in linear probing way.\n\nThis trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing\nmost of the model and only training specific layers for efficient adaptation to new tasks.\n\nAttributes:\n data (dict): Dataset configuration containing channels, class names, and number of classes.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "copy.deepcopy", "ultralytics.models.yolo.segment.SegmentationTrainer", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.RANK", "train.YOLOETrainer", "train.YOLOETrainerFromScratch", "train.YOLOEVPTrainer", "val.YOLOESegValidator", "SegmentationTrainer" ], "chunk_id": "class_YOLOEPESegTrainer_87305228" }, { "content": "class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):\n \"\"\"Trainer for YOLOE segmentation models trained from scratch without pretrained weights.\"\"\"\n\n pass", "chunk_type": "class", "name": "YOLOESegTrainerFromScratch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 118, "end_line": 121, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Trainer for YOLOE segmentation models trained from scratch without pretrained weights.", "parameters": null, "return_type": null, "decorators": [], 
"complexity_score": null, "dependencies": [ "copy.copy", "copy.deepcopy", "ultralytics.models.yolo.segment.SegmentationTrainer", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.RANK", "train.YOLOETrainer", "train.YOLOETrainerFromScratch", "train.YOLOEVPTrainer", "val.YOLOESegValidator", "YOLOETrainerFromScratch", "YOLOESegTrainer" ], "chunk_id": "class_YOLOESegTrainerFromScratch_ac610604" }, { "content": "class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):\n \"\"\"Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities.\"\"\"\n\n pass", "chunk_type": "class", "name": "YOLOESegVPTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py", "start_line": 124, "end_line": 127, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.copy", "copy.deepcopy", "ultralytics.models.yolo.segment.SegmentationTrainer", "ultralytics.nn.tasks.YOLOESegModel", "ultralytics.utils.RANK", "train.YOLOETrainer", "train.YOLOETrainerFromScratch", "train.YOLOEVPTrainer", "val.YOLOESegValidator", "YOLOEVPTrainer", "YOLOESegTrainerFromScratch" ], "chunk_id": "class_YOLOESegVPTrainer_848dc5cf" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_f9e40614" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_9f78b6b9" }, { "content": "from typing import Any, Dict, Optional, Union", "chunk_type": "import", "name": "Any, Dict, Optional, Union", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, Dict, Optional, Union_05bd7c6f" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_34502e9c" }, { "content": "from torch.nn import functional as F", "chunk_type": "import", "name": "functional", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_functional_5511c916" }, { "content": "from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset", 
"chunk_type": "import", "name": "YOLOConcatDataset, build_dataloader, build_yolo_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 84, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOConcatDataset, build_dataloader, build_yolo_dataset_3e8eb75d" }, { "content": "from ultralytics.data.augment import LoadVisualPrompt", "chunk_type": "import", "name": "LoadVisualPrompt", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LoadVisualPrompt_7f66e959" }, { "content": "from ultralytics.data.utils import check_det_dataset", "chunk_type": "import", "name": "check_det_dataset", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_check_det_dataset_76dcfbb2" }, { "content": "from ultralytics.models.yolo.detect import DetectionValidator", "chunk_type": "import", "name": "DetectionValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DetectionValidator_60711022" }, { "content": "from ultralytics.models.yolo.segment import SegmentationValidator", "chunk_type": "import", "name": "SegmentationValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SegmentationValidator_c1531403" }, { "content": "from ultralytics.nn.modules.head import YOLOEDetect", "chunk_type": "import", "name": "YOLOEDetect", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 15, "end_line": 15, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEDetect_bcda8aa3" }, { "content": "from ultralytics.nn.tasks import YOLOEModel", "chunk_type": "import", "name": "YOLOEModel", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEModel_589c4e76" }, { "content": "from ultralytics.utils import LOGGER, TQDM", "chunk_type": "import", "name": "LOGGER, TQDM", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
"complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, TQDM_dde9a176" }, { "content": "from ultralytics.utils.torch_utils import select_device, smart_inference_mode", "chunk_type": "import", "name": "select_device, smart_inference_mode", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 77, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_select_device, smart_inference_mode_c3de7d32" }, { "content": "class YOLOEDetectValidator(DetectionValidator):\n \"\"\"\n A validator class for YOLOE detection models that handles both text and visual prompt embeddings.\n\n This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.\n It supports validation using either text prompts or visual prompt embeddings extracted from training samples,\n enabling flexible evaluation strategies for prompt-based object detection.\n\n Attributes:\n device (torch.device): The device on which validation is performed.\n args (namespace): Configuration arguments for validation.\n dataloader (DataLoader): DataLoader for validation data.\n\n Methods:\n get_visual_pe: Extract visual prompt embeddings from training samples.\n preprocess: Preprocess batch data ensuring visuals are on the same device as images.\n get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.\n __call__: Run validation using either text or visual prompt embeddings.\n\n Examples:\n Validate with text prompts\n >>> validator = YOLOEDetectValidator()\n >>> stats = validator(model=model, load_vp=False)\n\n Validate with visual prompts\n >>> stats = validator(model=model, refer_data=\"path/to/data.yaml\", load_vp=True)\n \"\"\"\n\n @smart_inference_mode()\n def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:\n \"\"\"\n Extract visual prompt embeddings from training samples.\n\n This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.\n It normalizes the embeddings and handles cases where no samples exist for a class by setting their\n embeddings to zero.\n\n Args:\n dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.\n model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.\n\n Returns:\n (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).\n \"\"\"\n assert isinstance(model, YOLOEModel)\n names = [name.split(\"/\", 1)[0] for name in list(dataloader.dataset.data[\"names\"].values())]\n visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)\n cls_visual_num = torch.zeros(len(names))\n\n desc = \"Get visual prompt embeddings from samples\"\n\n # Count samples per class\n for batch in dataloader:\n cls = batch[\"cls\"].squeeze(-1).to(torch.int).unique()\n count = torch.bincount(cls, minlength=len(names))\n cls_visual_num += count\n\n cls_visual_num = cls_visual_num.to(self.device)\n\n # Extract visual prompt embeddings\n pbar = TQDM(dataloader, total=len(dataloader), desc=desc)\n for batch in pbar:\n batch = self.preprocess(batch)\n preds = model.get_visual_pe(batch[\"img\"], visual=batch[\"visuals\"]) # (B, max_n, embed_dim)\n\n batch_idx = batch[\"batch_idx\"]\n for i in range(preds.shape[0]):\n cls = batch[\"cls\"][batch_idx == 
i].squeeze(-1).to(torch.int).unique(sorted=True)\n pad_cls = torch.ones(preds.shape[1], device=self.device) * -1\n pad_cls[: len(cls)] = cls\n for c in cls:\n visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]\n\n # Normalize embeddings for classes with samples, set others to zero\n visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)\n visual_pe[cls_visual_num == 0] = 0\n return visual_pe.unsqueeze(0)\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess batch data, ensuring visuals are on the same device as images.\"\"\"\n batch = super().preprocess(batch)\n if \"visuals\" in batch:\n batch[\"visuals\"] = batch[\"visuals\"].to(batch[\"img\"].device)\n return batch\n\n def get_vpe_dataloader(self, data: Dict[str, Any]) -> torch.utils.data.DataLoader:\n \"\"\"\n Create a dataloader for LVIS training visual prompt samples.\n\n This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.\n It applies necessary transformations including LoadVisualPrompt and configurations to the dataset\n for validation purposes.\n\n Args:\n data (dict): Dataset configuration dictionary containing paths and settings.\n\n Returns:\n (torch.utils.data.DataLoader): The dataloader for visual prompt samples.\n \"\"\"\n dataset = build_yolo_dataset(\n self.args,\n data.get(self.args.split, data.get(\"val\")),\n self.args.batch,\n data,\n mode=\"val\",\n rect=False,\n )\n if isinstance(dataset, YOLOConcatDataset):\n for d in dataset.datasets:\n d.transforms.append(LoadVisualPrompt())\n else:\n dataset.transforms.append(LoadVisualPrompt())\n return build_dataloader(\n dataset,\n self.args.batch,\n self.args.workers,\n shuffle=False,\n rank=-1,\n )\n\n @smart_inference_mode()\n def __call__(\n self,\n trainer: Optional[Any] = None,\n model: Optional[Union[YOLOEModel, str]] = None,\n refer_data: Optional[str] = None,\n load_vp: bool = False,\n ) -> Dict[str, Any]:\n \"\"\"\n Run validation on the model using either text or visual prompt embeddings.\n\n This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.\n It supports validation during training (using a trainer object) or standalone validation with a provided\n model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.\n\n Args:\n trainer (object, optional): Trainer object containing the model and device.\n model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.\n refer_data (str, optional): Path to reference data for visual prompts.\n load_vp (bool): Whether to load visual prompts. 
If False, text prompts are used.\n\n Returns:\n (dict): Validation statistics containing metrics computed during validation.\n \"\"\"\n if trainer is not None:\n self.device = trainer.device\n model = trainer.ema.ema\n names = [name.split(\"/\", 1)[0] for name in list(self.dataloader.dataset.data[\"names\"].values())]\n\n if load_vp:\n LOGGER.info(\"Validate using the visual prompt.\")\n self.args.half = False\n # Directly use the same dataloader for visual embeddings extracted during training\n vpe = self.get_visual_pe(self.dataloader, model)\n model.set_classes(names, vpe)\n else:\n LOGGER.info(\"Validate using the text prompt.\")\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n stats = super().__call__(trainer, model)\n else:\n if refer_data is not None:\n assert load_vp, \"Refer data is only used for visual prompt validation.\"\n self.device = select_device(self.args.device)\n\n if isinstance(model, (str, Path)):\n from ultralytics.nn.tasks import attempt_load_weights\n\n model = attempt_load_weights(model, device=self.device, inplace=True)\n model.eval().to(self.device)\n data = check_det_dataset(refer_data or self.args.data)\n names = [name.split(\"/\", 1)[0] for name in list(data[\"names\"].values())]\n\n if load_vp:\n LOGGER.info(\"Validate using the visual prompt.\")\n self.args.half = False\n # TODO: need to check if the names from refer data is consistent with the evaluated dataset\n # could use same dataset or refer to extract visual prompt embeddings\n dataloader = self.get_vpe_dataloader(data)\n vpe = self.get_visual_pe(dataloader, model)\n model.set_classes(names, vpe)\n stats = super().__call__(model=deepcopy(model))\n elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], \"lrpc\"): # prompt-free\n return super().__call__(trainer, model)\n else:\n LOGGER.info(\"Validate using the text prompt.\")\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n stats = super().__call__(model=deepcopy(model))\n return stats", "chunk_type": "class", "name": "YOLOEDetectValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 21, "end_line": 210, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "A validator class for YOLOE detection models that handles both text and visual prompt embeddings.\n\nThis class extends DetectionValidator to provide specialized validation functionality for YOLOE models.\nIt supports validation using either text prompts or visual prompt embeddings extracted from training samples,\nenabling flexible evaluation strategies for prompt-based object detection.\n\nAttributes:\n device (torch.device): The device on which validation is performed.\n args (namespace): Configuration arguments for validation.\n dataloader (DataLoader): DataLoader for validation data.\n\nMethods:\n get_visual_pe: Extract visual prompt embeddings from training samples.\n preprocess: Preprocess batch data ensuring visuals are on the same device as images.\n get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.\n __call__: Run validation using either text or visual prompt embeddings.\n\nExamples:\n Validate with text prompts\n >>> validator = YOLOEDetectValidator()\n >>> stats = validator(model=model, load_vp=False)\n\n Validate with visual prompts\n >>> stats = validator(model=model, refer_data=\"path/to/data.yaml\", load_vp=True)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", 
"pathlib.Path", "typing.Any", "typing.Dict", "typing.Optional", "typing.Union", "torch", "torch.nn.functional", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_dataloader", "ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.data.utils.check_det_dataset", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.segment.SegmentationValidator", "ultralytics.nn.modules.head.YOLOEDetect", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.nn.tasks.attempt_load_weights", "DetectionValidator" ], "chunk_id": "class_YOLOEDetectValidator_f61b65ec" }, { "content": "class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):\n \"\"\"YOLOE segmentation validator that supports both text and visual prompt embeddings.\"\"\"\n\n pass", "chunk_type": "class", "name": "YOLOESegValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py", "start_line": 213, "end_line": 216, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "YOLOE segmentation validator that supports both text and visual prompt embeddings.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy.deepcopy", "pathlib.Path", "typing.Any", "typing.Dict", "typing.Optional", "typing.Union", "torch", "torch.nn.functional", "ultralytics.data.YOLOConcatDataset", "ultralytics.data.build_dataloader", "ultralytics.data.build_yolo_dataset", "ultralytics.data.augment.LoadVisualPrompt", "ultralytics.data.utils.check_det_dataset", "ultralytics.models.yolo.detect.DetectionValidator", "ultralytics.models.yolo.segment.SegmentationValidator", "ultralytics.nn.modules.head.YOLOEDetect", "ultralytics.nn.tasks.YOLOEModel", "ultralytics.utils.LOGGER", "ultralytics.utils.TQDM", "ultralytics.utils.torch_utils.select_device", "ultralytics.utils.torch_utils.smart_inference_mode", "ultralytics.nn.tasks.attempt_load_weights", "YOLOEDetectValidator", "SegmentationValidator" ], "chunk_id": "class_YOLOESegValidator_bbed2d6f" }, { "content": "from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor", "chunk_type": "import", "name": "YOLOEVPDetectPredictor, YOLOEVPSegPredictor", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEVPDetectPredictor, YOLOEVPSegPredictor_08ad326a" }, { "content": "from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer", "chunk_type": "import", "name": "YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 108, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer_7dca38c0" }, { "content": "from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer", "chunk_type": "import", "name": 
"YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 104, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer_d356eaa9" }, { "content": "from .val import YOLOEDetectValidator, YOLOESegValidator", "chunk_type": "import", "name": "YOLOEDetectValidator, YOLOESegValidator", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_YOLOEDetectValidator, YOLOESegValidator_ba883fe7" }, { "content": "__all__ = [\n \"YOLOETrainer\",\n \"YOLOEPETrainer\",\n \"YOLOESegTrainer\",\n \"YOLOEDetectValidator\",\n \"YOLOESegValidator\",\n \"YOLOEPESegTrainer\",\n \"YOLOESegTrainerFromScratch\",\n \"YOLOESegVPTrainer\",\n \"YOLOEVPTrainer\",\n \"YOLOEPEFreeTrainer\",\n \"YOLOEVPDetectPredictor\",\n \"YOLOEVPSegPredictor\",\n \"YOLOETrainerFromScratch\",\n]", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py", "start_line": 8, "end_line": 22, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___7ce3e0c1" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\activation.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_5eac896e" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\activation.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_cf4e57aa" }, { "content": "class AGLU(nn.Module):\n \"\"\"\n Unified activation function module from AGLU.\n\n This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the\n AGLU (Adaptive Gated Linear Unit) approach.\n\n Attributes:\n act (nn.Softplus): Softplus activation function with negative beta.\n lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution.\n kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution.\n\n Methods:\n forward: Compute the forward pass of the Unified activation function.\n\n Examples:\n >>> import torch\n >>> m = AGLU()\n >>> input = torch.randn(2)\n >>> output = m(input)\n >>> print(output.shape)\n torch.Size([2])\n\n References:\n https://github.com/kostas1515/AGLU\n \"\"\"\n\n def __init__(self, device=None, dtype=None) -> None:\n \"\"\"Initialize the Unified activation function with learnable parameters.\"\"\"\n super().__init__()\n self.act = nn.Softplus(beta=-1.0)\n 
self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # lambda parameter\n self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply the Adaptive Gated Linear Unit (AGLU) activation function.\n\n This forward method implements the AGLU activation function with learnable parameters lambda and kappa.\n The function applies a transformation that adaptively combines linear and non-linear components.\n\n Args:\n x (torch.Tensor): Input tensor to apply the activation function to.\n\n Returns:\n (torch.Tensor): Output tensor after applying the AGLU activation function, with the same shape as the input.\n \"\"\"\n lam = torch.clamp(self.lambd, min=0.0001) # Clamp lambda to avoid division by zero\n return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))", "chunk_type": "class", "name": "AGLU", "file_path": "ultralytics\\ultralytics\\nn\\modules\\activation.py", "start_line": 8, "end_line": 56, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": "Unified activation function module from AGLU.\n\nThis class implements a parameterized activation function with learnable parameters lambda and kappa, based on the\nAGLU (Adaptive Gated Linear Unit) approach.\n\nAttributes:\n act (nn.Softplus): Softplus activation function with negative beta.\n lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution.\n kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution.\n\nMethods:\n forward: Compute the forward pass of the Unified activation function.\n\nExamples:\n >>> import torch\n >>> m = AGLU()\n >>> input = torch.randn(2)\n >>> output = m(input)\n >>> print(output.shape)\n torch.Size([2])\n\nReferences:\n https://github.com/kostas1515/AGLU", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_AGLU_856d70d7" }, { "content": "from typing import List, Optional, Tuple", "chunk_type": "import", "name": "List, Optional, Tuple", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple_3be83d74" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_8d4ac29b" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_f1a1c3a8" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 31, "parent_name": null, 
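Because `Softplus(beta=-1)` computes `-log(1 + e^{-z})`, the `AGLU.forward` expression `exp((1/λ)·softplus(κx − log λ))` reduces to the generalized sigmoid `(1 + λ·e^{−κx})^{−1/λ}`. A small sanity check of that equivalence; the import path follows the `file_path` recorded in the chunk metadata above:

```python
import torch

from ultralytics.nn.modules.activation import AGLU  # module path per the chunk metadata

m = AGLU()
x = torch.linspace(-3, 3, 7)
lam = torch.clamp(m.lambd, min=0.0001)  # same clamp as AGLU.forward
closed_form = (1 + lam * torch.exp(-m.kappa * x)) ** (-1 / lam)
print(torch.allclose(m(x), closed_form, atol=1e-5))  # True
```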
"docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_9c8fb0f7" }, { "content": "from ultralytics.utils.torch_utils import fuse_conv_and_bn", "chunk_type": "import", "name": "fuse_conv_and_bn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_fuse_conv_and_bn_5c231990" }, { "content": "from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad", "chunk_type": "import", "name": "Conv, DWConv, GhostConv, LightConv, RepConv, autopad", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Conv, DWConv, GhostConv, LightConv, RepConv, autopad_4db33685" }, { "content": "from .transformer import TransformerBlock", "chunk_type": "import", "name": "TransformerBlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TransformerBlock_5f97c6ea" }, { "content": "__all__ = (\n \"DFL\",\n \"HGBlock\",\n \"HGStem\",\n \"SPP\",\n \"SPPF\",\n \"C1\",\n \"C2\",\n \"C3\",\n \"C2f\",\n \"C2fAttn\",\n \"ImagePoolingAttn\",\n \"ContrastiveHead\",\n \"BNContrastiveHead\",\n \"C3x\",\n \"C3TR\",\n \"C3Ghost\",\n \"GhostBottleneck\",\n \"Bottleneck\",\n \"BottleneckCSP\",\n \"Proto\",\n \"RepC3\",\n \"ResNetLayer\",\n \"RepNCSPELAN4\",\n \"ELAN1\",\n \"ADown\",\n \"AConv\",\n \"SPPELAN\",\n \"CBFuse\",\n \"CBLinear\",\n \"C3k2\",\n \"C2fPSA\",\n \"C2PSA\",\n \"RepVGGDW\",\n \"CIB\",\n \"C2fCIB\",\n \"Attention\",\n \"PSA\",\n \"SCDown\",\n \"TorchVision\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 15, "end_line": 55, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___2c799b66" }, { "content": "class DFL(nn.Module):\n \"\"\"\n Integral module of Distribution Focal Loss (DFL).\n\n Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391\n \"\"\"\n\n def __init__(self, c1: int = 16):\n \"\"\"\n Initialize a convolutional layer with a given number of input channels.\n\n Args:\n c1 (int): Number of input channels.\n \"\"\"\n super().__init__()\n self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)\n x = torch.arange(c1, dtype=torch.float)\n self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))\n self.c1 = c1\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply the DFL module to input tensor and return transformed output.\"\"\"\n b, _, a = x.shape # batch, channels, anchors\n return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)", "chunk_type": "class", "name": "DFL", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", 
"start_line": 58, "end_line": 81, "start_col": 0, "end_col": 91, "parent_name": null, "docstring": "Integral module of Distribution Focal Loss (DFL).\n\nProposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_DFL_1fd1ef42" }, { "content": "class Proto(nn.Module):\n \"\"\"Ultralytics YOLO models mask Proto module for segmentation models.\"\"\"\n\n def __init__(self, c1: int, c_: int = 256, c2: int = 32):\n \"\"\"\n Initialize the Ultralytics YOLO models mask Proto module with specified number of protos and masks.\n\n Args:\n c1 (int): Input channels.\n c_ (int): Intermediate channels.\n c2 (int): Output channels (number of protos).\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c_, k=3)\n self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')\n self.cv2 = Conv(c_, c_, k=3)\n self.cv3 = Conv(c_, c2)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Perform a forward pass through layers using an upsampled input image.\"\"\"\n return self.cv3(self.cv2(self.upsample(self.cv1(x))))", "chunk_type": "class", "name": "Proto", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 85, "end_line": 105, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": "Ultralytics YOLO models mask Proto module for segmentation models.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_Proto_d5d6f109" }, { "content": "class HGStem(nn.Module):\n \"\"\"\n StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.\n\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py\n \"\"\"\n\n def __init__(self, c1: int, cm: int, c2: int):\n \"\"\"\n Initialize the StemBlock of PPHGNetV2.\n\n Args:\n c1 (int): Input channels.\n cm (int): Middle channels.\n c2 (int): Output channels.\n \"\"\"\n super().__init__()\n self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())\n self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())\n self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())\n self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())\n self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())\n self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of a PPHGNetV2 backbone layer.\"\"\"\n x = self.stem1(x)\n x = F.pad(x, [0, 1, 0, 1])\n x2 = self.stem2a(x)\n x2 = F.pad(x2, [0, 1, 0, 1])\n x2 = self.stem2b(x2)\n x1 = self.pool(x)\n x = torch.cat([x1, x2], dim=1)\n x = self.stem3(x)\n x = self.stem4(x)\n return x", "chunk_type": "class", "name": "HGStem", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 108, "end_line": 143, "start_col": 0, "end_col": 16, 
"parent_name": null, "docstring": "StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.\n\nhttps://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_HGStem_49c65caf" }, { "content": "class HGBlock(nn.Module):\n \"\"\"\n HG_Block of PPHGNetV2 with 2 convolutions and LightConv.\n\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n cm: int,\n c2: int,\n k: int = 3,\n n: int = 6,\n lightconv: bool = False,\n shortcut: bool = False,\n act: nn.Module = nn.ReLU(),\n ):\n \"\"\"\n Initialize HGBlock with specified parameters.\n\n Args:\n c1 (int): Input channels.\n cm (int): Middle channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n n (int): Number of LightConv or Conv blocks.\n lightconv (bool): Whether to use LightConv.\n shortcut (bool): Whether to use shortcut connection.\n act (nn.Module): Activation function.\n \"\"\"\n super().__init__()\n block = LightConv if lightconv else Conv\n self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))\n self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv\n self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv\n self.add = shortcut and c1 == c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of a PPHGNetV2 backbone layer.\"\"\"\n y = [x]\n y.extend(m(y[-1]) for m in self.m)\n y = self.ec(self.sc(torch.cat(y, 1)))\n return y + x if self.add else y", "chunk_type": "class", "name": "HGBlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 146, "end_line": 189, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "HG_Block of PPHGNetV2 with 2 convolutions and LightConv.\n\nhttps://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_HGBlock_934eb761" }, { "content": "class SPP(nn.Module):\n \"\"\"Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.\"\"\"\n\n def __init__(self, c1: int, c2: int, k: Tuple[int, ...] 
= (5, 9, 13)):\n \"\"\"\n Initialize the SPP layer with input/output channels and pooling kernel sizes.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (tuple): Kernel sizes for max pooling.\n \"\"\"\n super().__init__()\n c_ = c1 // 2 # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)\n self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of the SPP layer, performing spatial pyramid pooling.\"\"\"\n x = self.cv1(x)\n return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))", "chunk_type": "class", "name": "SPP", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 192, "end_line": 213, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_SPP_936a261d" }, { "content": "class SPPF(nn.Module):\n \"\"\"Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.\"\"\"\n\n def __init__(self, c1: int, c2: int, k: int = 5):\n \"\"\"\n Initialize the SPPF layer with given input/output channels and kernel size.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n\n Notes:\n This module is equivalent to SPP(k=(5, 9, 13)).\n \"\"\"\n super().__init__()\n c_ = c1 // 2 # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c_ * 4, c2, 1, 1)\n self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply sequential pooling operations to input and return concatenated feature maps.\"\"\"\n y = [self.cv1(x)]\n y.extend(self.m(y[-1]) for _ in range(3))\n return self.cv2(torch.cat(y, 1))", "chunk_type": "class", "name": "SPPF", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 216, "end_line": 241, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_SPPF_4ada2457" }, { "content": "class C1(nn.Module):\n \"\"\"CSP Bottleneck with 1 convolution.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1):\n \"\"\"\n Initialize the CSP Bottleneck with 1 convolution.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of convolutions.\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c2, 1, 1)\n self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply convolution and residual connection to input tensor.\"\"\"\n y = self.cv1(x)\n return 
self.m(y) + y", "chunk_type": "class", "name": "C1", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 244, "end_line": 263, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "CSP Bottleneck with 1 convolution.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C1_39b5aedd" }, { "content": "class C2(nn.Module):\n \"\"\"CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize a CSP Bottleneck with 2 convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n self.c = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)\n # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()\n self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the CSP bottleneck with 2 convolutions.\"\"\"\n a, b = self.cv1(x).chunk(2, 1)\n return self.cv2(torch.cat((self.m(a), b), 1))", "chunk_type": "class", "name": "C2", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 266, "end_line": 291, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": "CSP Bottleneck with 2 convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C2_cbb49698" }, { "content": "class C2f(nn.Module):\n \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize a CSP bottleneck with 2 convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n self.c = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)\n self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through C2f layer.\"\"\"\n y = list(self.cv1(x).chunk(2, 1))\n y.extend(m(y[-1]) for m in self.m)\n return self.cv2(torch.cat(y, 1))\n\n def forward_split(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass using split() instead of 
chunk().\"\"\"\n y = self.cv1(x).split((self.c, self.c), 1)\n y = [y[0], y[1]]\n y.extend(m(y[-1]) for m in self.m)\n return self.cv2(torch.cat(y, 1))", "chunk_type": "class", "name": "C2f", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 294, "end_line": 326, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Faster Implementation of CSP Bottleneck with 2 convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C2f_947106e9" }, { "content": "class C3(nn.Module):\n \"\"\"CSP Bottleneck with 3 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize the CSP Bottleneck with 3 convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c1, c_, 1, 1)\n self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)\n self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the CSP bottleneck with 3 convolutions.\"\"\"\n return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))", "chunk_type": "class", "name": "C3", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 329, "end_line": 353, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": "CSP Bottleneck with 3 convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C3_eccd69e4" }, { "content": "class C3x(C3):\n \"\"\"C3 module with cross-convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C3 module with cross-convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n self.c_ = int(c2 * e)\n self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))", "chunk_type": "class", "name": "C3x", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 356, "end_line": 373, "start_col": 0, "end_col": 119, "parent_name": null, "docstring": "C3 module with cross-convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", 
"typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C3" ], "chunk_id": "class_C3x_a86a4bf9" }, { "content": "class RepC3(nn.Module):\n \"\"\"Rep C3.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 3, e: float = 1.0):\n \"\"\"\n Initialize CSP Bottleneck with a single convolution.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of RepConv blocks.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c1, c_, 1, 1)\n self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])\n self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of RepC3 module.\"\"\"\n return self.cv3(self.m(self.cv1(x)) + self.cv2(x))", "chunk_type": "class", "name": "RepC3", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 376, "end_line": 398, "start_col": 0, "end_col": 58, "parent_name": null, "docstring": "Rep C3.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_RepC3_a0dfcb95" }, { "content": "class C3TR(C3):\n \"\"\"C3 module with TransformerBlock().\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C3 module with TransformerBlock.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Transformer blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e)\n self.m = TransformerBlock(c_, c_, 4, n)", "chunk_type": "class", "name": "C3TR", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 401, "end_line": 418, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "C3 module with TransformerBlock().", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C3" ], "chunk_id": "class_C3TR_fe347d69" }, { "content": "class C3Ghost(C3):\n \"\"\"C3 module with GhostBottleneck().\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C3 module with GhostBottleneck.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Ghost bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e) # hidden channels\n self.m 
= nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))", "chunk_type": "class", "name": "C3Ghost", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 421, "end_line": 438, "start_col": 0, "end_col": 76, "parent_name": null, "docstring": "C3 module with GhostBottleneck().", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C3" ], "chunk_id": "class_C3Ghost_42be1069" }, { "content": "class GhostBottleneck(nn.Module):\n \"\"\"Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones.\"\"\"\n\n def __init__(self, c1: int, c2: int, k: int = 3, s: int = 1):\n \"\"\"\n Initialize Ghost Bottleneck module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n s (int): Stride.\n \"\"\"\n super().__init__()\n c_ = c2 // 2\n self.conv = nn.Sequential(\n GhostConv(c1, c_, 1, 1), # pw\n DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw\n GhostConv(c_, c2, 1, 1, act=False), # pw-linear\n )\n self.shortcut = (\n nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply skip connection and concatenation to input tensor.\"\"\"\n return self.conv(x) + self.shortcut(x)", "chunk_type": "class", "name": "GhostBottleneck", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 441, "end_line": 467, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_GhostBottleneck_87d94d85" }, { "content": "class Bottleneck(nn.Module):\n \"\"\"Standard bottleneck.\"\"\"\n\n def __init__(\n self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5\n ):\n \"\"\"\n Initialize a standard bottleneck module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n shortcut (bool): Whether to use shortcut connection.\n g (int): Groups for convolutions.\n k (tuple): Kernel sizes for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, k[0], 1)\n self.cv2 = Conv(c_, c2, k[1], 1, g=g)\n self.add = shortcut and c1 == c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply bottleneck with optional shortcut connection.\"\"\"\n return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))", "chunk_type": "class", "name": "Bottleneck", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 470, "end_line": 495, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": "Standard bottleneck.", "parameters": null, "return_type": null, "decorators": [], 
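`Bottleneck` only applies the residual add when both `shortcut=True` and the channel counts match (`c1 == c2`); otherwise it silently falls back to a plain two-conv block. A small illustration of that condition, with the import path taken from the chunk's `file_path`:

```python
import torch

from ultralytics.nn.modules.block import Bottleneck  # module path per the chunk metadata

x = torch.randn(1, 64, 20, 20)
with_skip = Bottleneck(64, 64, shortcut=True)   # c1 == c2, so the residual add is active
no_skip = Bottleneck(64, 128, shortcut=True)    # c1 != c2, so `add` stays False despite shortcut=True
print(with_skip.add, no_skip.add)               # True False
print(with_skip(x).shape, no_skip(x).shape)     # torch.Size([1, 64, 20, 20]) torch.Size([1, 128, 20, 20])
```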
"complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_Bottleneck_b6c60a65" }, { "content": "class BottleneckCSP(nn.Module):\n \"\"\"CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize CSP Bottleneck.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)\n self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)\n self.cv4 = Conv(2 * c_, c2, 1, 1)\n self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)\n self.act = nn.SiLU()\n self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply CSP bottleneck with 3 convolutions.\"\"\"\n y1 = self.cv3(self.m(self.cv1(x)))\n y2 = self.cv2(x)\n return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))", "chunk_type": "class", "name": "BottleneckCSP", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 498, "end_line": 527, "start_col": 0, "end_col": 66, "parent_name": null, "docstring": "CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_BottleneckCSP_d85467ed" }, { "content": "class ResNetBlock(nn.Module):\n \"\"\"ResNet block with standard convolution layers.\"\"\"\n\n def __init__(self, c1: int, c2: int, s: int = 1, e: int = 4):\n \"\"\"\n Initialize ResNet block.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n s (int): Stride.\n e (int): Expansion ratio.\n \"\"\"\n super().__init__()\n c3 = e * c2\n self.cv1 = Conv(c1, c2, k=1, s=1, act=True)\n self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True)\n self.cv3 = Conv(c2, c3, k=1, act=False)\n self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the ResNet block.\"\"\"\n return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x))", "chunk_type": "class", "name": "ResNetBlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 530, "end_line": 552, "start_col": 0, "end_col": 73, "parent_name": null, "docstring": "ResNet block with standard convolution layers.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", 
"ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_ResNetBlock_0bc30168" }, { "content": "class ResNetLayer(nn.Module):\n \"\"\"ResNet layer with multiple ResNet blocks.\"\"\"\n\n def __init__(self, c1: int, c2: int, s: int = 1, is_first: bool = False, n: int = 1, e: int = 4):\n \"\"\"\n Initialize ResNet layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n s (int): Stride.\n is_first (bool): Whether this is the first layer.\n n (int): Number of ResNet blocks.\n e (int): Expansion ratio.\n \"\"\"\n super().__init__()\n self.is_first = is_first\n\n if self.is_first:\n self.layer = nn.Sequential(\n Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n )\n else:\n blocks = [ResNetBlock(c1, c2, s, e=e)]\n blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])\n self.layer = nn.Sequential(*blocks)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the ResNet layer.\"\"\"\n return self.layer(x)", "chunk_type": "class", "name": "ResNetLayer", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 555, "end_line": 584, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "ResNet layer with multiple ResNet blocks.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_ResNetLayer_07394af0" }, { "content": "class MaxSigmoidAttnBlock(nn.Module):\n \"\"\"Max Sigmoid attention block.\"\"\"\n\n def __init__(self, c1: int, c2: int, nh: int = 1, ec: int = 128, gc: int = 512, scale: bool = False):\n \"\"\"\n Initialize MaxSigmoidAttnBlock.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n nh (int): Number of heads.\n ec (int): Embedding channels.\n gc (int): Guide channels.\n scale (bool): Whether to use learnable scale parameter.\n \"\"\"\n super().__init__()\n self.nh = nh\n self.hc = c2 // nh\n self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None\n self.gl = nn.Linear(gc, ec)\n self.bias = nn.Parameter(torch.zeros(nh))\n self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)\n self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0\n\n def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of MaxSigmoidAttnBlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n guide (torch.Tensor): Guide tensor.\n\n Returns:\n (torch.Tensor): Output tensor after attention.\n \"\"\"\n bs, _, h, w = x.shape\n\n guide = self.gl(guide)\n guide = guide.view(bs, guide.shape[1], self.nh, self.hc)\n embed = self.ec(x) if self.ec is not None else x\n embed = embed.view(bs, self.nh, self.hc, h, w)\n\n aw = torch.einsum(\"bmchw,bnmc->bmhwn\", embed, guide)\n aw = aw.max(dim=-1)[0]\n aw = aw / (self.hc**0.5)\n aw = aw + self.bias[None, :, None, None]\n aw = aw.sigmoid() * self.scale\n\n x = self.proj_conv(x)\n x = x.view(bs, self.nh, -1, h, w)\n x = x * aw.unsqueeze(2)\n return x.view(bs, -1, h, w)", "chunk_type": "class", "name": "MaxSigmoidAttnBlock", "file_path": 
"ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 587, "end_line": 638, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Max Sigmoid attention block.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_MaxSigmoidAttnBlock_e6b84d8e" }, { "content": "class C2fAttn(nn.Module):\n \"\"\"C2f module with an additional attn module.\"\"\"\n\n def __init__(\n self,\n c1: int,\n c2: int,\n n: int = 1,\n ec: int = 128,\n nh: int = 1,\n gc: int = 512,\n shortcut: bool = False,\n g: int = 1,\n e: float = 0.5,\n ):\n \"\"\"\n Initialize C2f module with attention mechanism.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n ec (int): Embedding channels for attention.\n nh (int): Number of heads for attention.\n gc (int): Guide channels for attention.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n self.c = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2)\n self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))\n self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)\n\n def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through C2f layer with attention.\n\n Args:\n x (torch.Tensor): Input tensor.\n guide (torch.Tensor): Guide tensor for attention.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n y = list(self.cv1(x).chunk(2, 1))\n y.extend(m(y[-1]) for m in self.m)\n y.append(self.attn(y[-1], guide))\n return self.cv2(torch.cat(y, 1))\n\n def forward_split(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass using split() instead of chunk().\n\n Args:\n x (torch.Tensor): Input tensor.\n guide (torch.Tensor): Guide tensor for attention.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n y = list(self.cv1(x).split((self.c, self.c), 1))\n y.extend(m(y[-1]) for m in self.m)\n y.append(self.attn(y[-1], guide))\n return self.cv2(torch.cat(y, 1))", "chunk_type": "class", "name": "C2fAttn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 641, "end_line": 707, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "C2f module with an additional attn module.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C2fAttn_77c66cdf" }, { "content": "class ImagePoolingAttn(nn.Module):\n \"\"\"ImagePoolingAttn: Enhance the text embeddings with image-aware information.\"\"\"\n\n def __init__(\n self, ec: int = 256, ch: Tuple[int, ...] 
= (), ct: int = 512, nh: int = 8, k: int = 3, scale: bool = False\n ):\n \"\"\"\n Initialize ImagePoolingAttn module.\n\n Args:\n ec (int): Embedding channels.\n ch (tuple): Channel dimensions for feature maps.\n ct (int): Channel dimension for text embeddings.\n nh (int): Number of attention heads.\n k (int): Kernel size for pooling.\n scale (bool): Whether to use learnable scale parameter.\n \"\"\"\n super().__init__()\n\n nf = len(ch)\n self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec))\n self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))\n self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))\n self.proj = nn.Linear(ec, ct)\n self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0\n self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch])\n self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)])\n self.ec = ec\n self.nh = nh\n self.nf = nf\n self.hc = ec // nh\n self.k = k\n\n def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of ImagePoolingAttn.\n\n Args:\n x (List[torch.Tensor]): List of input feature maps.\n text (torch.Tensor): Text embeddings.\n\n Returns:\n (torch.Tensor): Enhanced text embeddings.\n \"\"\"\n bs = x[0].shape[0]\n assert len(x) == self.nf\n num_patches = self.k**2\n x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)]\n x = torch.cat(x, dim=-1).transpose(1, 2)\n q = self.query(text)\n k = self.key(x)\n v = self.value(x)\n\n # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1)\n q = q.reshape(bs, -1, self.nh, self.hc)\n k = k.reshape(bs, -1, self.nh, self.hc)\n v = v.reshape(bs, -1, self.nh, self.hc)\n\n aw = torch.einsum(\"bnmc,bkmc->bmnk\", q, k)\n aw = aw / (self.hc**0.5)\n aw = F.softmax(aw, dim=-1)\n\n x = torch.einsum(\"bmnk,bkmc->bnmc\", aw, v)\n x = self.proj(x.reshape(bs, -1, self.ec))\n return x * self.scale + text", "chunk_type": "class", "name": "ImagePoolingAttn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 710, "end_line": 774, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "ImagePoolingAttn: Enhance the text embeddings with image-aware information.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_ImagePoolingAttn_7cec88fc" }, { "content": "class ContrastiveHead(nn.Module):\n \"\"\"Implements contrastive learning head for region-text similarity in vision-language models.\"\"\"\n\n def __init__(self):\n \"\"\"Initialize ContrastiveHead with region-text similarity parameters.\"\"\"\n super().__init__()\n # NOTE: use -10.0 to keep the init cls loss consistency with other losses\n self.bias = nn.Parameter(torch.tensor([-10.0]))\n self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())\n\n def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward function of contrastive learning.\n\n Args:\n x (torch.Tensor): Image features.\n w (torch.Tensor): Text features.\n\n Returns:\n (torch.Tensor): Similarity scores.\n \"\"\"\n x = 
F.normalize(x, dim=1, p=2)\n w = F.normalize(w, dim=-1, p=2)\n x = torch.einsum(\"bchw,bkc->bkhw\", x, w)\n return x * self.logit_scale.exp() + self.bias", "chunk_type": "class", "name": "ContrastiveHead", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 777, "end_line": 801, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": "Implements contrastive learning head for region-text similarity in vision-language models.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_ContrastiveHead_c63f9e5c" }, { "content": "class BNContrastiveHead(nn.Module):\n \"\"\"\n Batch Norm Contrastive Head using batch norm instead of l2-normalization.\n\n Args:\n embed_dims (int): Embed dimensions of text and image features.\n \"\"\"\n\n def __init__(self, embed_dims: int):\n \"\"\"\n Initialize BNContrastiveHead.\n\n Args:\n embed_dims (int): Embedding dimensions for features.\n \"\"\"\n super().__init__()\n self.norm = nn.BatchNorm2d(embed_dims)\n # NOTE: use -10.0 to keep the init cls loss consistency with other losses\n self.bias = nn.Parameter(torch.tensor([-10.0]))\n # use -1.0 is more stable\n self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))\n\n def fuse(self):\n \"\"\"Fuse the batch normalization layer in the BNContrastiveHead module.\"\"\"\n del self.norm\n del self.bias\n del self.logit_scale\n self.forward = self.forward_fuse\n\n def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n \"\"\"Passes input out unchanged.\"\"\"\n return x\n\n def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward function of contrastive learning with batch normalization.\n\n Args:\n x (torch.Tensor): Image features.\n w (torch.Tensor): Text features.\n\n Returns:\n (torch.Tensor): Similarity scores.\n \"\"\"\n x = self.norm(x)\n w = F.normalize(w, dim=-1, p=2)\n\n x = torch.einsum(\"bchw,bkc->bkhw\", x, w)\n return x * self.logit_scale.exp() + self.bias", "chunk_type": "class", "name": "BNContrastiveHead", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 804, "end_line": 852, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": "Batch Norm Contrastive Head using batch norm instead of l2-normalization.\n\nArgs:\n embed_dims (int): Embed dimensions of text and image features.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_BNContrastiveHead_523689c4" }, { "content": "class RepBottleneck(Bottleneck):\n \"\"\"Rep bottleneck.\"\"\"\n\n def __init__(\n self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5\n ):\n \"\"\"\n Initialize RepBottleneck.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n shortcut (bool): Whether to use shortcut connection.\n g (int): Groups for convolutions.\n 
k (tuple): Kernel sizes for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, shortcut, g, k, e)\n c_ = int(c2 * e) # hidden channels\n self.cv1 = RepConv(c1, c_, k[0], 1)", "chunk_type": "class", "name": "RepBottleneck", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 855, "end_line": 874, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Rep bottleneck.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "Bottleneck" ], "chunk_id": "class_RepBottleneck_872e277d" }, { "content": "class RepCSP(C3):\n \"\"\"Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize RepCSP layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of RepBottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e) # hidden channels\n self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))", "chunk_type": "class", "name": "RepCSP", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 877, "end_line": 894, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": "Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C3" ], "chunk_id": "class_RepCSP_44069c52" }, { "content": "class RepNCSPELAN4(nn.Module):\n \"\"\"CSP-ELAN.\"\"\"\n\n def __init__(self, c1: int, c2: int, c3: int, c4: int, n: int = 1):\n \"\"\"\n Initialize CSP-ELAN layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n c3 (int): Intermediate channels.\n c4 (int): Intermediate channels for RepCSP.\n n (int): Number of RepCSP blocks.\n \"\"\"\n super().__init__()\n self.c = c3 // 2\n self.cv1 = Conv(c1, c3, 1, 1)\n self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1))\n self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1))\n self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through RepNCSPELAN4 layer.\"\"\"\n y = list(self.cv1(x).chunk(2, 1))\n y.extend((m(y[-1])) for m in [self.cv2, self.cv3])\n return self.cv4(torch.cat(y, 1))\n\n def forward_split(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass using split() instead of chunk().\"\"\"\n y = list(self.cv1(x).split((self.c, self.c), 1))\n y.extend(m(y[-1]) for m in [self.cv2, self.cv3])\n return self.cv4(torch.cat(y, 1))", "chunk_type": "class", "name": "RepNCSPELAN4", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", 
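The channel bookkeeping in the RepNCSPELAN4 chunk above (cv1 produces `c3` channels split into two halves, each stacked branch adds `c4` channels, and cv4 fuses `c3 + 2*c4` back down to `c2`) is easiest to confirm with a dummy forward pass. Again a minimal sketch, assuming the installed `ultralytics` package exposes these classes from `ultralytics.nn.modules.block`; the chosen channel counts are illustrative, not canonical model settings.

import torch
from ultralytics.nn.modules.block import RepCSP, RepNCSPELAN4

x = torch.randn(1, 64, 32, 32)

# RepCSP behaves like C3 but stacks RepBottleneck blocks internally.
csp = RepCSP(c1=64, c2=64, n=2)
print(csp(x).shape)  # torch.Size([1, 64, 32, 32])

# CSP-ELAN: cv1 -> c3 channels, two branches of c4 channels each,
# then cv4 fuses c3 + 2*c4 = 64 + 64 = 128 channels down to c2.
elan = RepNCSPELAN4(c1=64, c2=128, c3=64, c4=32, n=1)
print(elan(x).shape)  # torch.Size([1, 128, 32, 32])

# forward_split() is numerically equivalent to forward(); it only swaps
# chunk() for split() on the cv1 output.
elan.eval()
with torch.no_grad():
    assert torch.allclose(elan(x), elan.forward_split(x))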
"start_line": 897, "end_line": 928, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "CSP-ELAN.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_RepNCSPELAN4_6f8bb196" }, { "content": "class ELAN1(RepNCSPELAN4):\n \"\"\"ELAN1 module with 4 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, c3: int, c4: int):\n \"\"\"\n Initialize ELAN1 layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n c3 (int): Intermediate channels.\n c4 (int): Intermediate channels for convolutions.\n \"\"\"\n super().__init__(c1, c2, c3, c4)\n self.c = c3 // 2\n self.cv1 = Conv(c1, c3, 1, 1)\n self.cv2 = Conv(c3 // 2, c4, 3, 1)\n self.cv3 = Conv(c4, c4, 3, 1)\n self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)", "chunk_type": "class", "name": "ELAN1", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 931, "end_line": 949, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": "ELAN1 module with 4 convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "RepNCSPELAN4" ], "chunk_id": "class_ELAN1_988f4e6d" }, { "content": "class AConv(nn.Module):\n \"\"\"AConv.\"\"\"\n\n def __init__(self, c1: int, c2: int):\n \"\"\"\n Initialize AConv module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c2, 3, 2, 1)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through AConv layer.\"\"\"\n x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)\n return self.cv1(x)", "chunk_type": "class", "name": "AConv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 952, "end_line": 969, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "AConv.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_AConv_6ac2b074" }, { "content": "class ADown(nn.Module):\n \"\"\"ADown.\"\"\"\n\n def __init__(self, c1: int, c2: int):\n \"\"\"\n Initialize ADown module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n \"\"\"\n super().__init__()\n self.c = c2 // 2\n self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)\n self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through ADown layer.\"\"\"\n x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)\n x1, x2 = x.chunk(2, 1)\n x1 = self.cv1(x1)\n x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)\n x2 = self.cv2(x2)\n return 
torch.cat((x1, x2), 1)", "chunk_type": "class", "name": "ADown", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 972, "end_line": 995, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "ADown.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_ADown_d418e88d" }, { "content": "class SPPELAN(nn.Module):\n \"\"\"SPP-ELAN.\"\"\"\n\n def __init__(self, c1: int, c2: int, c3: int, k: int = 5):\n \"\"\"\n Initialize SPP-ELAN block.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n c3 (int): Intermediate channels.\n k (int): Kernel size for max pooling.\n \"\"\"\n super().__init__()\n self.c = c3\n self.cv1 = Conv(c1, c3, 1, 1)\n self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n self.cv5 = Conv(4 * c3, c2, 1, 1)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through SPPELAN layer.\"\"\"\n y = [self.cv1(x)]\n y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])\n return self.cv5(torch.cat(y, 1))", "chunk_type": "class", "name": "SPPELAN", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 998, "end_line": 1023, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "SPP-ELAN.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_SPPELAN_54c70ab7" }, { "content": "class CBLinear(nn.Module):\n \"\"\"CBLinear.\"\"\"\n\n def __init__(self, c1: int, c2s: List[int], k: int = 1, s: int = 1, p: Optional[int] = None, g: int = 1):\n \"\"\"\n Initialize CBLinear module.\n\n Args:\n c1 (int): Input channels.\n c2s (List[int]): List of output channel sizes.\n k (int): Kernel size.\n s (int): Stride.\n p (int | None): Padding.\n g (int): Groups.\n \"\"\"\n super().__init__()\n self.c2s = c2s\n self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)\n\n def forward(self, x: torch.Tensor) -> List[torch.Tensor]:\n \"\"\"Forward pass through CBLinear layer.\"\"\"\n return self.conv(x).split(self.c2s, dim=1)", "chunk_type": "class", "name": "CBLinear", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1026, "end_line": 1047, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": "CBLinear.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_CBLinear_8170ffd3" }, { "content": 
"class CBFuse(nn.Module):\n \"\"\"CBFuse.\"\"\"\n\n def __init__(self, idx: List[int]):\n \"\"\"\n Initialize CBFuse module.\n\n Args:\n idx (List[int]): Indices for feature selection.\n \"\"\"\n super().__init__()\n self.idx = idx\n\n def forward(self, xs: List[torch.Tensor]) -> torch.Tensor:\n \"\"\"\n Forward pass through CBFuse layer.\n\n Args:\n xs (List[torch.Tensor]): List of input tensors.\n\n Returns:\n (torch.Tensor): Fused output tensor.\n \"\"\"\n target_size = xs[-1].shape[2:]\n res = [F.interpolate(x[self.idx[i]], size=target_size, mode=\"nearest\") for i, x in enumerate(xs[:-1])]\n return torch.sum(torch.stack(res + xs[-1:]), dim=0)", "chunk_type": "class", "name": "CBFuse", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1050, "end_line": 1075, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "CBFuse.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_CBFuse_36c0c305" }, { "content": "class C3f(nn.Module):\n \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize CSP bottleneck layer with two convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c1, c_, 1, 1)\n self.cv3 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2)\n self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through C3f layer.\"\"\"\n y = [self.cv2(x), self.cv1(x)]\n y.extend(m(y[-1]) for m in self.m)\n return self.cv3(torch.cat(y, 1))", "chunk_type": "class", "name": "C3f", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1078, "end_line": 1104, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Faster Implementation of CSP Bottleneck with 2 convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C3f_4070fa33" }, { "content": "class C3k2(C2f):\n \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(\n self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True\n ):\n \"\"\"\n Initialize C3k2 module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of blocks.\n c3k (bool): Whether to use C3k blocks.\n e (float): Expansion ratio.\n g (int): Groups for convolutions.\n shortcut (bool): Whether to 
use shortcut connections.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n self.m = nn.ModuleList(\n C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)\n )", "chunk_type": "class", "name": "C3k2", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1107, "end_line": 1128, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Faster Implementation of CSP Bottleneck with 2 convolutions.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C2f" ], "chunk_id": "class_C3k2_37260c67" }, { "content": "class C3k(C3):\n \"\"\"C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):\n \"\"\"\n Initialize C3k module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n k (int): Kernel size.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e) # hidden channels\n # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))\n self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))", "chunk_type": "class", "name": "C3k", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1131, "end_line": 1150, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": "C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C3" ], "chunk_id": "class_C3k_f04eef9c" }, { "content": "class RepVGGDW(torch.nn.Module):\n \"\"\"RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture.\"\"\"\n\n def __init__(self, ed: int) -> None:\n \"\"\"\n Initialize RepVGGDW module.\n\n Args:\n ed (int): Input and output channels.\n \"\"\"\n super().__init__()\n self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)\n self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)\n self.dim = ed\n self.act = nn.SiLU()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform a forward pass of the RepVGGDW block.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after applying the depth wise separable convolution.\n \"\"\"\n return self.act(self.conv(x) + self.conv1(x))\n\n def forward_fuse(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform a forward pass of the RepVGGDW block without fusing the convolutions.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after applying 
the depth wise separable convolution.\n \"\"\"\n return self.act(self.conv(x))\n\n @torch.no_grad()\n def fuse(self):\n \"\"\"\n Fuse the convolutional layers in the RepVGGDW block.\n\n This method fuses the convolutional layers and updates the weights and biases accordingly.\n \"\"\"\n conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)\n conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)\n\n conv_w = conv.weight\n conv_b = conv.bias\n conv1_w = conv1.weight\n conv1_b = conv1.bias\n\n conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])\n\n final_conv_w = conv_w + conv1_w\n final_conv_b = conv_b + conv1_b\n\n conv.weight.data.copy_(final_conv_w)\n conv.bias.data.copy_(final_conv_b)\n\n self.conv = conv\n del self.conv1", "chunk_type": "class", "name": "RepVGGDW", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1153, "end_line": 1217, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "torch.nn.Module" ], "chunk_id": "class_RepVGGDW_87a6072f" }, { "content": "class CIB(nn.Module):\n \"\"\"\n Conditional Identity Block (CIB) module.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.\n e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.\n lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.\n \"\"\"\n\n def __init__(self, c1: int, c2: int, shortcut: bool = True, e: float = 0.5, lk: bool = False):\n \"\"\"\n Initialize the CIB module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n shortcut (bool): Whether to use shortcut connection.\n e (float): Expansion ratio.\n lk (bool): Whether to use RepVGGDW.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = nn.Sequential(\n Conv(c1, c1, 3, g=c1),\n Conv(c1, 2 * c_, 1),\n RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_),\n Conv(2 * c_, c2, 1),\n Conv(c2, c2, 3, g=c2),\n )\n\n self.add = shortcut and c1 == c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of the CIB module.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return x + self.cv1(x) if self.add else self.cv1(x)", "chunk_type": "class", "name": "CIB", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1220, "end_line": 1265, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "Conditional Identity Block (CIB) module.\n\nArgs:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.\n e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.\n lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. 
Defaults to False.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_CIB_348da215" }, { "content": "class C2fCIB(C2f):\n \"\"\"\n C2fCIB class represents a convolutional block with C2f and CIB modules.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n n (int, optional): Number of CIB modules to stack. Defaults to 1.\n shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.\n lk (bool, optional): Whether to use local key connection. Defaults to False.\n g (int, optional): Number of groups for grouped convolution. Defaults to 1.\n e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.\n \"\"\"\n\n def __init__(\n self, c1: int, c2: int, n: int = 1, shortcut: bool = False, lk: bool = False, g: int = 1, e: float = 0.5\n ):\n \"\"\"\n Initialize C2fCIB module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of CIB modules.\n shortcut (bool): Whether to use shortcut connection.\n lk (bool): Whether to use local key connection.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))", "chunk_type": "class", "name": "C2fCIB", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1268, "end_line": 1298, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": "C2fCIB class represents a convolutional block with C2f and CIB modules.\n\nArgs:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n n (int, optional): Number of CIB modules to stack. Defaults to 1.\n shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.\n lk (bool, optional): Whether to use local key connection. Defaults to False.\n g (int, optional): Number of groups for grouped convolution. Defaults to 1.\n e (float, optional): Expansion ratio for CIB modules. 
Defaults to 0.5.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C2f" ], "chunk_id": "class_C2fCIB_ddc7b517" }, { "content": "class Attention(nn.Module):\n \"\"\"\n Attention module that performs self-attention on the input tensor.\n\n Args:\n dim (int): The input tensor dimension.\n num_heads (int): The number of attention heads.\n attn_ratio (float): The ratio of the attention key dimension to the head dimension.\n\n Attributes:\n num_heads (int): The number of attention heads.\n head_dim (int): The dimension of each attention head.\n key_dim (int): The dimension of the attention key.\n scale (float): The scaling factor for the attention scores.\n qkv (Conv): Convolutional layer for computing the query, key, and value.\n proj (Conv): Convolutional layer for projecting the attended values.\n pe (Conv): Convolutional layer for positional encoding.\n \"\"\"\n\n def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):\n \"\"\"\n Initialize multi-head attention module.\n\n Args:\n dim (int): Input dimension.\n num_heads (int): Number of attention heads.\n attn_ratio (float): Attention ratio for key dimension.\n \"\"\"\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = dim // num_heads\n self.key_dim = int(self.head_dim * attn_ratio)\n self.scale = self.key_dim**-0.5\n nh_kd = self.key_dim * num_heads\n h = dim + nh_kd * 2\n self.qkv = Conv(dim, h, 1, act=False)\n self.proj = Conv(dim, dim, 1, act=False)\n self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of the Attention module.\n\n Args:\n x (torch.Tensor): The input tensor.\n\n Returns:\n (torch.Tensor): The output tensor after self-attention.\n \"\"\"\n B, C, H, W = x.shape\n N = H * W\n qkv = self.qkv(x)\n q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(\n [self.key_dim, self.key_dim, self.head_dim], dim=2\n )\n\n attn = (q.transpose(-2, -1) @ k) * self.scale\n attn = attn.softmax(dim=-1)\n x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))\n x = self.proj(x)\n return x", "chunk_type": "class", "name": "Attention", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1301, "end_line": 1361, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Attention module that performs self-attention on the input tensor.\n\nArgs:\n dim (int): The input tensor dimension.\n num_heads (int): The number of attention heads.\n attn_ratio (float): The ratio of the attention key dimension to the head dimension.\n\nAttributes:\n num_heads (int): The number of attention heads.\n head_dim (int): The dimension of each attention head.\n key_dim (int): The dimension of the attention key.\n scale (float): The scaling factor for the attention scores.\n qkv (Conv): Convolutional layer for computing the query, key, and value.\n proj (Conv): Convolutional layer for projecting the attended values.\n pe (Conv): Convolutional layer for positional encoding.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", 
"torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_Attention_0b70783b" }, { "content": "class PSABlock(nn.Module):\n \"\"\"\n PSABlock class implementing a Position-Sensitive Attention block for neural networks.\n\n This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers\n with optional shortcut connections.\n\n Attributes:\n attn (Attention): Multi-head attention module.\n ffn (nn.Sequential): Feed-forward neural network module.\n add (bool): Flag indicating whether to add shortcut connections.\n\n Methods:\n forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.\n\n Examples:\n Create a PSABlock and perform a forward pass\n >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)\n >>> input_tensor = torch.randn(1, 128, 32, 32)\n >>> output_tensor = psablock(input_tensor)\n \"\"\"\n\n def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:\n \"\"\"\n Initialize the PSABlock.\n\n Args:\n c (int): Input and output channels.\n attn_ratio (float): Attention ratio for key dimension.\n num_heads (int): Number of attention heads.\n shortcut (bool): Whether to use shortcut connections.\n \"\"\"\n super().__init__()\n\n self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)\n self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))\n self.add = shortcut\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Execute a forward pass through PSABlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feed-forward processing.\n \"\"\"\n x = x + self.attn(x) if self.add else self.attn(x)\n x = x + self.ffn(x) if self.add else self.ffn(x)\n return x", "chunk_type": "class", "name": "PSABlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1364, "end_line": 1414, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "PSABlock class implementing a Position-Sensitive Attention block for neural networks.\n\nThis class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers\nwith optional shortcut connections.\n\nAttributes:\n attn (Attention): Multi-head attention module.\n ffn (nn.Sequential): Feed-forward neural network module.\n add (bool): Flag indicating whether to add shortcut connections.\n\nMethods:\n forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.\n\nExamples:\n Create a PSABlock and perform a forward pass\n >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)\n >>> input_tensor = torch.randn(1, 128, 32, 32)\n >>> output_tensor = psablock(input_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_PSABlock_04e7ec49" }, { "content": "class PSA(nn.Module):\n \"\"\"\n PSA class for implementing 
Position-Sensitive Attention in neural networks.\n\n This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to\n input tensors, enhancing feature extraction and processing capabilities.\n\n Attributes:\n c (int): Number of hidden channels after applying the initial convolution.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n attn (Attention): Attention module for position-sensitive attention.\n ffn (nn.Sequential): Feed-forward network for further processing.\n\n Methods:\n forward: Applies position-sensitive attention and feed-forward network to the input tensor.\n\n Examples:\n Create a PSA module and apply it to an input tensor\n >>> psa = PSA(c1=128, c2=128, e=0.5)\n >>> input_tensor = torch.randn(1, 128, 64, 64)\n >>> output_tensor = psa.forward(input_tensor)\n \"\"\"\n\n def __init__(self, c1: int, c2: int, e: float = 0.5):\n \"\"\"\n Initialize PSA module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n assert c1 == c2\n self.c = int(c1 * e)\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv(2 * self.c, c1, 1)\n\n self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)\n self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Execute forward pass in PSA module.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feed-forward processing.\n \"\"\"\n a, b = self.cv1(x).split((self.c, self.c), dim=1)\n b = b + self.attn(b)\n b = b + self.ffn(b)\n return self.cv2(torch.cat((a, b), 1))", "chunk_type": "class", "name": "PSA", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1417, "end_line": 1472, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "PSA class for implementing Position-Sensitive Attention in neural networks.\n\nThis class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to\ninput tensors, enhancing feature extraction and processing capabilities.\n\nAttributes:\n c (int): Number of hidden channels after applying the initial convolution.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n attn (Attention): Attention module for position-sensitive attention.\n ffn (nn.Sequential): Feed-forward network for further processing.\n\nMethods:\n forward: Applies position-sensitive attention and feed-forward network to the input tensor.\n\nExamples:\n Create a PSA module and apply it to an input tensor\n >>> psa = PSA(c1=128, c2=128, e=0.5)\n >>> input_tensor = torch.randn(1, 128, 64, 64)\n >>> output_tensor = psa.forward(input_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_PSA_92a2933c" }, { "content": "class C2PSA(nn.Module):\n \"\"\"\n C2PSA module with attention 
mechanism for enhanced feature extraction and processing.\n\n This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing\n capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.\n\n Attributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.\n\n Methods:\n forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.\n\n Notes:\n This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.\n\n Examples:\n >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)\n >>> input_tensor = torch.randn(1, 256, 64, 64)\n >>> output_tensor = c2psa(input_tensor)\n \"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C2PSA module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of PSABlock modules.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n assert c1 == c2\n self.c = int(c1 * e)\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv(2 * self.c, c1, 1)\n\n self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Process the input tensor through a series of PSA blocks.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n a, b = self.cv1(x).split((self.c, self.c), dim=1)\n b = self.m(b)\n return self.cv2(torch.cat((a, b), 1))", "chunk_type": "class", "name": "C2PSA", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1475, "end_line": 1530, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "C2PSA module with attention mechanism for enhanced feature extraction and processing.\n\nThis module implements a convolutional block with attention mechanisms to enhance feature extraction and processing\ncapabilities. 
It includes a series of PSABlock modules for self-attention and feed-forward operations.\n\nAttributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.\n\nMethods:\n forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.\n\nNotes:\n This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.\n\nExamples:\n >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)\n >>> input_tensor = torch.randn(1, 256, 64, 64)\n >>> output_tensor = c2psa(input_tensor)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_C2PSA_f5974732" }, { "content": "class C2fPSA(C2f):\n \"\"\"\n C2fPSA module with enhanced feature extraction using PSA blocks.\n\n This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.\n\n Attributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.ModuleList): List of PSA blocks for feature extraction.\n\n Methods:\n forward: Performs a forward pass through the C2fPSA module.\n forward_split: Performs a forward pass using split() instead of chunk().\n\n Examples:\n >>> import torch\n >>> from ultralytics.models.common import C2fPSA\n >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> output = model(x)\n >>> print(output.shape)\n \"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C2fPSA module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of PSABlock modules.\n e (float): Expansion ratio.\n \"\"\"\n assert c1 == c2\n super().__init__(c1, c2, n=n, e=e)\n self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))", "chunk_type": "class", "name": "C2fPSA", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1533, "end_line": 1570, "start_col": 0, "end_col": 106, "parent_name": null, "docstring": "C2fPSA module with enhanced feature extraction using PSA blocks.\n\nThis class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.\n\nAttributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.ModuleList): List of PSA blocks for feature extraction.\n\nMethods:\n forward: Performs a forward pass through the C2fPSA module.\n forward_split: Performs a forward pass using split() instead of chunk().\n\nExamples:\n >>> import torch\n >>> from ultralytics.models.common import C2fPSA\n >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)\n >>> x = torch.randn(1, 64, 128, 128)\n 
>>> output = model(x)\n >>> print(output.shape)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "C2f" ], "chunk_id": "class_C2fPSA_8f7445a7" }, { "content": "class SCDown(nn.Module):\n \"\"\"\n SCDown module for downsampling with separable convolutions.\n\n This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in\n efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.\n\n Attributes:\n cv1 (Conv): Pointwise convolution layer that reduces the number of channels.\n cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.\n\n Methods:\n forward: Applies the SCDown module to the input tensor.\n\n Examples:\n >>> import torch\n >>> from ultralytics import SCDown\n >>> model = SCDown(c1=64, c2=128, k=3, s=2)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> y = model(x)\n >>> print(y.shape)\n torch.Size([1, 128, 64, 64])\n \"\"\"\n\n def __init__(self, c1: int, c2: int, k: int, s: int):\n \"\"\"\n Initialize SCDown module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n s (int): Stride.\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c2, 1, 1)\n self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply convolution and downsampling to the input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Downsampled output tensor.\n \"\"\"\n return self.cv2(self.cv1(x))", "chunk_type": "class", "name": "SCDown", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1573, "end_line": 1621, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "SCDown module for downsampling with separable convolutions.\n\nThis module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in\nefficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.\n\nAttributes:\n cv1 (Conv): Pointwise convolution layer that reduces the number of channels.\n cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.\n\nMethods:\n forward: Applies the SCDown module to the input tensor.\n\nExamples:\n >>> import torch\n >>> from ultralytics import SCDown\n >>> model = SCDown(c1=64, c2=128, k=3, s=2)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> y = model(x)\n >>> print(y.shape)\n torch.Size([1, 128, 64, 64])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_SCDown_49a4a000" }, { "content": "class TorchVision(nn.Module):\n \"\"\"\n TorchVision module to allow loading any torchvision model.\n\n This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating 
or unwrapping layers.\n\n Attributes:\n m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.\n\n Args:\n model (str): Name of the torchvision model to load.\n weights (str, optional): Pre-trained weights to load. Default is \"DEFAULT\".\n unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.\n truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.\n split (bool, optional): Returns output from intermediate child modules as list. Default is False.\n \"\"\"\n\n def __init__(\n self, model: str, weights: str = \"DEFAULT\", unwrap: bool = True, truncate: int = 2, split: bool = False\n ):\n \"\"\"\n Load the model and weights from torchvision.\n\n Args:\n model (str): Name of the torchvision model to load.\n weights (str): Pre-trained weights to load.\n unwrap (bool): Whether to unwrap the model.\n truncate (int): Number of layers to truncate.\n split (bool): Whether to split the output.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n super().__init__()\n if hasattr(torchvision.models, \"get_model\"):\n self.m = torchvision.models.get_model(model, weights=weights)\n else:\n self.m = torchvision.models.__dict__[model](pretrained=bool(weights))\n if unwrap:\n layers = list(self.m.children())\n if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin\n layers = [*list(layers[0].children()), *layers[1:]]\n self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))\n self.split = split\n else:\n self.split = False\n self.m.head = self.m.heads = nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through the model.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor | List[torch.Tensor]): Output tensor or list of tensors.\n \"\"\"\n if self.split:\n y = [x]\n y.extend(m(y[-1]) for m in self.m)\n else:\n y = self.m(x)\n return y", "chunk_type": "class", "name": "TorchVision", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1624, "end_line": 1686, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "TorchVision module to allow loading any torchvision model.\n\nThis class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.\n\nAttributes:\n m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.\n\nArgs:\n model (str): Name of the torchvision model to load.\n weights (str, optional): Pre-trained weights to load. Default is \"DEFAULT\".\n unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.\n truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.\n split (bool, optional): Returns output from intermediate child modules as list. 
Default is False.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_TorchVision_6883286a" }, { "content": "class AAttn(nn.Module):\n \"\"\"\n Area-attention module for YOLO models, providing efficient attention mechanisms.\n\n This module implements an area-based attention mechanism that processes input features in a spatially-aware manner,\n making it particularly effective for object detection tasks.\n\n Attributes:\n area (int): Number of areas the feature map is divided.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n head_dim (int): Dimension of each attention head.\n qkv (Conv): Convolution layer for computing query, key and value tensors.\n proj (Conv): Projection convolution layer.\n pe (Conv): Position encoding convolution layer.\n\n Methods:\n forward: Applies area-attention to input tensor.\n\n Examples:\n >>> attn = AAttn(dim=256, num_heads=8, area=4)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])\n \"\"\"\n\n def __init__(self, dim: int, num_heads: int, area: int = 1):\n \"\"\"\n Initialize an Area-attention module for YOLO models.\n\n Args:\n dim (int): Number of hidden channels.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n area (int): Number of areas the feature map is divided.\n \"\"\"\n super().__init__()\n self.area = area\n\n self.num_heads = num_heads\n self.head_dim = head_dim = dim // num_heads\n all_head_dim = head_dim * self.num_heads\n\n self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)\n self.proj = Conv(all_head_dim, dim, 1, act=False)\n self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Process the input tensor through the area-attention.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after area-attention.\n \"\"\"\n B, C, H, W = x.shape\n N = H * W\n\n qkv = self.qkv(x).flatten(2).transpose(1, 2)\n if self.area > 1:\n qkv = qkv.reshape(B * self.area, N // self.area, C * 3)\n B, N, _ = qkv.shape\n q, k, v = (\n qkv.view(B, N, self.num_heads, self.head_dim * 3)\n .permute(0, 2, 3, 1)\n .split([self.head_dim, self.head_dim, self.head_dim], dim=2)\n )\n attn = (q.transpose(-2, -1) @ k) * (self.head_dim**-0.5)\n attn = attn.softmax(dim=-1)\n x = v @ attn.transpose(-2, -1)\n x = x.permute(0, 3, 1, 2)\n v = v.permute(0, 3, 1, 2)\n\n if self.area > 1:\n x = x.reshape(B // self.area, N * self.area, C)\n v = v.reshape(B // self.area, N * self.area, C)\n B, N, _ = x.shape\n\n x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n v = v.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n\n x = x + self.pe(v)\n return self.proj(x)", "chunk_type": "class", "name": "AAttn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1689, "end_line": 1772, "start_col": 0, "end_col": 27, "parent_name": null, "docstring": "Area-attention module for YOLO models, providing efficient attention mechanisms.\n\nThis module implements an area-based attention mechanism that processes input features in a spatially-aware 
manner,\nmaking it particularly effective for object detection tasks.\n\nAttributes:\n area (int): Number of areas the feature map is divided.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n head_dim (int): Dimension of each attention head.\n qkv (Conv): Convolution layer for computing query, key and value tensors.\n proj (Conv): Projection convolution layer.\n pe (Conv): Position encoding convolution layer.\n\nMethods:\n forward: Applies area-attention to input tensor.\n\nExamples:\n >>> attn = AAttn(dim=256, num_heads=8, area=4)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_AAttn_f769942c" }, { "content": "class ABlock(nn.Module):\n \"\"\"\n Area-attention block module for efficient feature extraction in YOLO models.\n\n This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.\n It uses a novel area-based attention approach that is more efficient than traditional self-attention while\n maintaining effectiveness.\n\n Attributes:\n attn (AAttn): Area-attention module for processing spatial features.\n mlp (nn.Sequential): Multi-layer perceptron for feature transformation.\n\n Methods:\n _init_weights: Initializes module weights using truncated normal distribution.\n forward: Applies area-attention and feed-forward processing to input tensor.\n\n Examples:\n >>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])\n \"\"\"\n\n def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1.2, area: int = 1):\n \"\"\"\n Initialize an Area-attention block module.\n\n Args:\n dim (int): Number of input channels.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n mlp_ratio (float): Expansion ratio for MLP hidden dimension.\n area (int): Number of areas the feature map is divided.\n \"\"\"\n super().__init__()\n\n self.attn = AAttn(dim, num_heads=num_heads, area=area)\n mlp_hidden_dim = int(dim * mlp_ratio)\n self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))\n\n self.apply(self._init_weights)\n\n def _init_weights(self, m: nn.Module):\n \"\"\"\n Initialize weights using a truncated normal distribution.\n\n Args:\n m (nn.Module): Module to initialize.\n \"\"\"\n if isinstance(m, nn.Conv2d):\n nn.init.trunc_normal_(m.weight, std=0.02)\n if m.bias is not None:\n nn.init.constant_(m.bias, 0)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through ABlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after area-attention and feed-forward processing.\n \"\"\"\n x = x + self.attn(x)\n return x + self.mlp(x)", "chunk_type": "class", "name": "ABlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1775, "end_line": 1840, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "Area-attention block module 
for efficient feature extraction in YOLO models.\n\nThis module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.\nIt uses a novel area-based attention approach that is more efficient than traditional self-attention while\nmaintaining effectiveness.\n\nAttributes:\n attn (AAttn): Area-attention module for processing spatial features.\n mlp (nn.Sequential): Multi-layer perceptron for feature transformation.\n\nMethods:\n _init_weights: Initializes module weights using truncated normal distribution.\n forward: Applies area-attention and feed-forward processing to input tensor.\n\nExamples:\n >>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_ABlock_d1e5ffc8" }, { "content": "class A2C2f(nn.Module):\n \"\"\"\n Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.\n\n This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature\n processing. It supports both area-attention and standard convolution modes.\n\n Attributes:\n cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.\n cv2 (Conv): Final 1x1 convolution layer that processes concatenated features.\n gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention.\n m (nn.ModuleList): List of either ABlock or C3k modules for feature processing.\n\n Methods:\n forward: Processes input through area-attention or standard convolution pathway.\n\n Examples:\n >>> m = A2C2f(512, 512, n=1, a2=True, area=1)\n >>> x = torch.randn(1, 512, 32, 32)\n >>> output = m(x)\n >>> print(output.shape)\n torch.Size([1, 512, 32, 32])\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n c2: int,\n n: int = 1,\n a2: bool = True,\n area: int = 1,\n residual: bool = False,\n mlp_ratio: float = 2.0,\n e: float = 0.5,\n g: int = 1,\n shortcut: bool = True,\n ):\n \"\"\"\n Initialize Area-Attention C2f module.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n n (int): Number of ABlock or C3k modules to stack.\n a2 (bool): Whether to use area attention blocks. 
If False, uses C3k blocks instead.\n area (int): Number of areas the feature map is divided.\n residual (bool): Whether to use residual connections with learnable gamma parameter.\n mlp_ratio (float): Expansion ratio for MLP hidden dimension.\n e (float): Channel expansion ratio for hidden channels.\n g (int): Number of groups for grouped convolutions.\n shortcut (bool): Whether to use shortcut connections in C3k blocks.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n assert c_ % 32 == 0, \"Dimension of ABlock must be a multiple of 32.\"\n\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv((1 + n) * c_, c2, 1)\n\n self.gamma = nn.Parameter(0.01 * torch.ones(c2), requires_grad=True) if a2 and residual else None\n self.m = nn.ModuleList(\n nn.Sequential(*(ABlock(c_, c_ // 32, mlp_ratio, area) for _ in range(2)))\n if a2\n else C3k(c_, c_, 2, shortcut, g)\n for _ in range(n)\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through A2C2f layer.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n y = [self.cv1(x)]\n y.extend(m(y[-1]) for m in self.m)\n y = self.cv2(torch.cat(y, 1))\n if self.gamma is not None:\n return x + self.gamma.view(-1, len(self.gamma), 1, 1) * y\n return y", "chunk_type": "class", "name": "A2C2f", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1843, "end_line": 1925, "start_col": 0, "end_col": 16, "parent_name": null, "docstring": "Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.\n\nThis module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature\nprocessing. It supports both area-attention and standard convolution modes.\n\nAttributes:\n cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.\n cv2 (Conv): Final 1x1 convolution layer that processes concatenated features.\n gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention.\n m (nn.ModuleList): List of either ABlock or C3k modules for feature processing.\n\nMethods:\n forward: Processes input through area-attention or standard convolution pathway.\n\nExamples:\n >>> m = A2C2f(512, 512, n=1, a2=True, area=1)\n >>> x = torch.randn(1, 512, 32, 32)\n >>> output = m(x)\n >>> print(output.shape)\n torch.Size([1, 512, 32, 32])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_A2C2f_da7e61c0" }, { "content": "class SwiGLUFFN(nn.Module):\n \"\"\"SwiGLU Feed-Forward Network for transformer-based architectures.\"\"\"\n\n def __init__(self, gc: int, ec: int, e: int = 4) -> None:\n \"\"\"\n Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.\n\n Args:\n gc (int): Guide channels.\n ec (int): Embedding channels.\n e (int): Expansion factor.\n \"\"\"\n super().__init__()\n self.w12 = nn.Linear(gc, e * ec)\n self.w3 = nn.Linear(e * ec // 2, ec)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply SwiGLU transformation to input features.\"\"\"\n x12 = self.w12(x)\n x1, x2 = x12.chunk(2, dim=-1)\n hidden = 
F.silu(x1) * x2\n return self.w3(hidden)", "chunk_type": "class", "name": "SwiGLUFFN", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1928, "end_line": 1949, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": "SwiGLU Feed-Forward Network for transformer-based architectures.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_SwiGLUFFN_2e5d045a" }, { "content": "class Residual(nn.Module):\n \"\"\"Residual connection wrapper for neural network modules.\"\"\"\n\n def __init__(self, m: nn.Module) -> None:\n \"\"\"\n Initialize residual module with the wrapped module.\n\n Args:\n m (nn.Module): Module to wrap with residual connection.\n \"\"\"\n super().__init__()\n self.m = m\n nn.init.zeros_(self.m.w3.bias)\n # For models with l scale, please change the initialization to\n # nn.init.constant_(self.m.w3.weight, 1e-6)\n nn.init.zeros_(self.m.w3.weight)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply residual connection to input features.\"\"\"\n return x + self.m(x)", "chunk_type": "class", "name": "Residual", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1952, "end_line": 1971, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Residual connection wrapper for neural network modules.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_Residual_4db16cd0" }, { "content": "class SAVPE(nn.Module):\n \"\"\"Spatial-Aware Visual Prompt Embedding module for feature enhancement.\"\"\"\n\n def __init__(self, ch: List[int], c3: int, embed: int):\n \"\"\"\n Initialize SAVPE module with channels, intermediate channels, and embedding dimension.\n\n Args:\n ch (List[int]): List of input channel dimensions.\n c3 (int): Intermediate channels.\n embed (int): Embedding dimension.\n \"\"\"\n super().__init__()\n self.cv1 = nn.ModuleList(\n nn.Sequential(\n Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()\n )\n for i, x in enumerate(ch)\n )\n\n self.cv2 = nn.ModuleList(\n nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())\n for i, x in enumerate(ch)\n )\n\n self.c = 16\n self.cv3 = nn.Conv2d(3 * c3, embed, 1)\n self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)\n self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)\n self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))\n\n def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input features and visual prompts to generate enhanced embeddings.\"\"\"\n y = [self.cv2[i](xi) for i, xi in enumerate(x)]\n y = self.cv4(torch.cat(y, dim=1))\n\n x = [self.cv1[i](xi) for i, xi in enumerate(x)]\n x = self.cv3(torch.cat(x, dim=1))\n\n B, C, H, W = x.shape\n\n Q = 
vp.shape[1]\n\n x = x.view(B, C, -1)\n\n y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)\n vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)\n\n y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))\n\n y = y.reshape(B, Q, self.c, -1)\n vp = vp.reshape(B, Q, 1, -1)\n\n score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min\n\n score = F.softmax(score, dim=-1, dtype=torch.float).to(score.dtype)\n\n aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)\n\n return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)", "chunk_type": "class", "name": "SAVPE", "file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py", "start_line": 1974, "end_line": 2033, "start_col": 0, "end_col": 87, "parent_name": null, "docstring": "Spatial-Aware Visual Prompt Embedding module for feature enhancement.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "typing.List", "typing.Optional", "typing.Tuple", "torch", "torch.nn", "torch.nn.functional", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "conv.Conv", "conv.DWConv", "conv.GhostConv", "conv.LightConv", "conv.RepConv", "conv.autopad", "transformer.TransformerBlock", "torchvision", "nn.Module" ], "chunk_id": "class_SAVPE_0cc8314c" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_197fdc46" }, { "content": "from typing import List", "chunk_type": "import", "name": "List", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List_aa56bce6" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_4e2a9c1b" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_905ccb83" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_ee8edbaf" }, { "content": "__all__ = (\n \"Conv\",\n \"Conv2\",\n \"LightConv\",\n \"DWConv\",\n \"DWConvTranspose2d\",\n \"ConvTranspose\",\n \"Focus\",\n \"GhostConv\",\n \"ChannelAttention\",\n \"SpatialAttention\",\n \"CBAM\",\n \"Concat\",\n \"RepConv\",\n 
\"Index\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 11, "end_line": 26, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___baa4b10a" }, { "content": "def autopad(k, p=None, d=1): # kernel, padding, dilation\n \"\"\"Pad to 'same' shape outputs.\"\"\"\n if d > 1:\n k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size\n if p is None:\n p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad\n return p", "chunk_type": "function", "name": "autopad", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 29, "end_line": 35, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": "Pad to 'same' shape outputs.", "parameters": [ "k", "p", "d" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn" ], "chunk_id": "function_autopad_28ae3a8c" }, { "content": "class Conv(nn.Module):\n \"\"\"\n Standard convolution module with batch normalization and activation.\n\n Attributes:\n conv (nn.Conv2d): Convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).\n \"\"\"\n\n default_act = nn.SiLU() # default activation\n\n def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):\n \"\"\"\n Initialize Conv layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int, optional): Padding.\n g (int): Groups.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)\n self.bn = nn.BatchNorm2d(c2)\n self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n def forward(self, x):\n \"\"\"\n Apply convolution, batch normalization and activation to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv(x)))\n\n def forward_fuse(self, x):\n \"\"\"\n Apply convolution and activation without batch normalization.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.conv(x))", "chunk_type": "class", "name": "Conv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 38, "end_line": 92, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Standard convolution module with batch normalization and activation.\n\nAttributes:\n conv (nn.Conv2d): Convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_Conv_1bc938d1" }, { "content": "class Conv2(Conv):\n \"\"\"\n Simplified RepConv module with Conv fusing.\n\n Attributes:\n conv (nn.Conv2d): Main 3x3 convolutional layer.\n cv2 (nn.Conv2d): Additional 1x1 
convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.\n \"\"\"\n\n def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):\n \"\"\"\n Initialize Conv2 layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int, optional): Padding.\n g (int): Groups.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)\n self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv\n\n def forward(self, x):\n \"\"\"\n Apply convolution, batch normalization and activation to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv(x) + self.cv2(x)))\n\n def forward_fuse(self, x):\n \"\"\"\n Apply fused convolution, batch normalization and activation to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv(x)))\n\n def fuse_convs(self):\n \"\"\"Fuse parallel convolutions.\"\"\"\n w = torch.zeros_like(self.conv.weight.data)\n i = [x // 2 for x in w.shape[2:]]\n w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()\n self.conv.weight.data += w\n self.__delattr__(\"cv2\")\n self.forward = self.forward_fuse", "chunk_type": "class", "name": "Conv2", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 95, "end_line": 154, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Simplified RepConv module with Conv fusing.\n\nAttributes:\n conv (nn.Conv2d): Main 3x3 convolutional layer.\n cv2 (nn.Conv2d): Additional 1x1 convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "Conv" ], "chunk_id": "class_Conv2_f2d55d14" }, { "content": "class LightConv(nn.Module):\n \"\"\"\n Light convolution module with 1x1 and depthwise convolutions.\n\n This implementation is based on the PaddleDetection HGNetV2 backbone.\n\n Attributes:\n conv1 (Conv): 1x1 convolution layer.\n conv2 (DWConv): Depthwise convolution layer.\n \"\"\"\n\n def __init__(self, c1, c2, k=1, act=nn.ReLU()):\n \"\"\"\n Initialize LightConv layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size for depthwise convolution.\n act (nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv1 = Conv(c1, c2, 1, act=False)\n self.conv2 = DWConv(c2, c2, k, act=act)\n\n def forward(self, x):\n \"\"\"\n Apply 2 convolutions to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.conv2(self.conv1(x))", "chunk_type": "class", "name": "LightConv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 157, "end_line": 192, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Light convolution module with 1x1 and depthwise convolutions.\n\nThis implementation is based on the PaddleDetection HGNetV2 backbone.\n\nAttributes:\n conv1 (Conv): 1x1 convolution layer.\n conv2 (DWConv): Depthwise convolution layer.", "parameters": 
null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_LightConv_f71a4c9f" }, { "content": "class DWConv(Conv):\n \"\"\"Depth-wise convolution module.\"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, d=1, act=True):\n \"\"\"\n Initialize depth-wise convolution with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)", "chunk_type": "class", "name": "DWConv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 195, "end_line": 210, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": "Depth-wise convolution module.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "Conv" ], "chunk_id": "class_DWConv_e26260ac" }, { "content": "class DWConvTranspose2d(nn.ConvTranspose2d):\n \"\"\"Depth-wise transpose convolution module.\"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):\n \"\"\"\n Initialize depth-wise transpose convolution with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p1 (int): Padding.\n p2 (int): Output padding.\n \"\"\"\n super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))", "chunk_type": "class", "name": "DWConvTranspose2d", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 213, "end_line": 228, "start_col": 0, "end_col": 71, "parent_name": null, "docstring": "Depth-wise transpose convolution module.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.ConvTranspose2d" ], "chunk_id": "class_DWConvTranspose2d_b683d7f6" }, { "content": "class ConvTranspose(nn.Module):\n \"\"\"\n Convolution transpose module with optional batch normalization and activation.\n\n Attributes:\n conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.\n bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).\n \"\"\"\n\n default_act = nn.SiLU() # default activation\n\n def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):\n \"\"\"\n Initialize ConvTranspose layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int): Padding.\n bn (bool): Use batch normalization.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)\n self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()\n self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n def forward(self, x):\n \"\"\"\n Apply transposed convolution, batch normalization and activation to input.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv_transpose(x)))\n\n def forward_fuse(self, x):\n \"\"\"\n Apply activation and convolution transpose operation to 
input.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.conv_transpose(x))", "chunk_type": "class", "name": "ConvTranspose", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 231, "end_line": 284, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Convolution transpose module with optional batch normalization and activation.\n\nAttributes:\n conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.\n bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_ConvTranspose_8c6636c1" }, { "content": "class Focus(nn.Module):\n \"\"\"\n Focus module for concentrating feature information.\n\n Slices input tensor into 4 parts and concatenates them in the channel dimension.\n\n Attributes:\n conv (Conv): Convolution layer.\n \"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):\n \"\"\"\n Initialize Focus module with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int, optional): Padding.\n g (int): Groups.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)\n # self.contract = Contract(gain=2)\n\n def forward(self, x):\n \"\"\"\n Apply Focus operation and convolution to input tensor.\n\n Input shape is (B, C, W, H) and output shape is (B, 4C, W/2, H/2).\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))", "chunk_type": "class", "name": "Focus", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 287, "end_line": 326, "start_col": 0, "end_col": 116, "parent_name": null, "docstring": "Focus module for concentrating feature information.\n\nSlices input tensor into 4 parts and concatenates them in the channel dimension.\n\nAttributes:\n conv (Conv): Convolution layer.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_Focus_023c9a71" }, { "content": "class GhostConv(nn.Module):\n \"\"\"\n Ghost Convolution module.\n\n Generates more features with fewer parameters by using cheap operations.\n\n Attributes:\n cv1 (Conv): Primary convolution.\n cv2 (Conv): Cheap operation convolution.\n\n References:\n https://github.com/huawei-noah/Efficient-AI-Backbones\n \"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, g=1, act=True):\n \"\"\"\n Initialize Ghost Convolution module with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n g (int): Groups.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n c_ = c2 // 2 # hidden channels\n self.cv1 = Conv(c1, c_, k, s, None, g, act=act)\n self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)\n\n def forward(self, x):\n \"\"\"\n Apply Ghost Convolution to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n 
Returns:\n (torch.Tensor): Output tensor with concatenated features.\n \"\"\"\n y = self.cv1(x)\n return torch.cat((y, self.cv2(y)), 1)", "chunk_type": "class", "name": "GhostConv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 330, "end_line": 372, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "Ghost Convolution module.\n\nGenerates more features with fewer parameters by using cheap operations.\n\nAttributes:\n cv1 (Conv): Primary convolution.\n cv2 (Conv): Cheap operation convolution.\n\nReferences:\n https://github.com/huawei-noah/Efficient-AI-Backbones", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_GhostConv_f69fae21" }, { "content": "class RepConv(nn.Module):\n \"\"\"\n RepConv module with training and deploy modes.\n\n This module is used in RT-DETR and can fuse convolutions during inference for efficiency.\n\n Attributes:\n conv1 (Conv): 3x3 convolution.\n conv2 (Conv): 1x1 convolution.\n bn (nn.BatchNorm2d, optional): Batch normalization for identity branch.\n act (nn.Module): Activation function.\n default_act (nn.Module): Default activation function (SiLU).\n\n References:\n https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py\n \"\"\"\n\n default_act = nn.SiLU() # default activation\n\n def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):\n \"\"\"\n Initialize RepConv module with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int): Padding.\n g (int): Groups.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n bn (bool): Use batch normalization for identity branch.\n deploy (bool): Deploy mode for inference.\n \"\"\"\n super().__init__()\n assert k == 3 and p == 1\n self.g = g\n self.c1 = c1\n self.c2 = c2\n self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None\n self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)\n self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)\n\n def forward_fuse(self, x):\n \"\"\"\n Forward pass for deploy mode.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.conv(x))\n\n def forward(self, x):\n \"\"\"\n Forward pass for training mode.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n id_out = 0 if self.bn is None else self.bn(x)\n return self.act(self.conv1(x) + self.conv2(x) + id_out)\n\n def get_equivalent_kernel_bias(self):\n \"\"\"\n Calculate equivalent kernel and bias by fusing convolutions.\n\n Returns:\n (torch.Tensor): Equivalent kernel\n (torch.Tensor): Equivalent bias\n \"\"\"\n kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)\n kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)\n kernelid, biasid = self._fuse_bn_tensor(self.bn)\n return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid\n\n @staticmethod\n def _pad_1x1_to_3x3_tensor(kernel1x1):\n \"\"\"\n Pad a 1x1 kernel to 3x3 size.\n\n Args:\n kernel1x1 (torch.Tensor): 1x1 convolution kernel.\n\n Returns:\n (torch.Tensor): Padded 3x3 kernel.\n \"\"\"\n if kernel1x1 is None:\n return 0\n else:\n return 
torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])\n\n def _fuse_bn_tensor(self, branch):\n \"\"\"\n Fuse batch normalization with convolution weights.\n\n Args:\n branch (Conv | nn.BatchNorm2d | None): Branch to fuse.\n\n Returns:\n kernel (torch.Tensor): Fused kernel.\n bias (torch.Tensor): Fused bias.\n \"\"\"\n if branch is None:\n return 0, 0\n if isinstance(branch, Conv):\n kernel = branch.conv.weight\n running_mean = branch.bn.running_mean\n running_var = branch.bn.running_var\n gamma = branch.bn.weight\n beta = branch.bn.bias\n eps = branch.bn.eps\n elif isinstance(branch, nn.BatchNorm2d):\n if not hasattr(self, \"id_tensor\"):\n input_dim = self.c1 // self.g\n kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)\n for i in range(self.c1):\n kernel_value[i, i % input_dim, 1, 1] = 1\n self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)\n kernel = self.id_tensor\n running_mean = branch.running_mean\n running_var = branch.running_var\n gamma = branch.weight\n beta = branch.bias\n eps = branch.eps\n std = (running_var + eps).sqrt()\n t = (gamma / std).reshape(-1, 1, 1, 1)\n return kernel * t, beta - running_mean * gamma / std\n\n def fuse_convs(self):\n \"\"\"Fuse convolutions for inference by creating a single equivalent convolution.\"\"\"\n if hasattr(self, \"conv\"):\n return\n kernel, bias = self.get_equivalent_kernel_bias()\n self.conv = nn.Conv2d(\n in_channels=self.conv1.conv.in_channels,\n out_channels=self.conv1.conv.out_channels,\n kernel_size=self.conv1.conv.kernel_size,\n stride=self.conv1.conv.stride,\n padding=self.conv1.conv.padding,\n dilation=self.conv1.conv.dilation,\n groups=self.conv1.conv.groups,\n bias=True,\n ).requires_grad_(False)\n self.conv.weight.data = kernel\n self.conv.bias.data = bias\n for para in self.parameters():\n para.detach_()\n self.__delattr__(\"conv1\")\n self.__delattr__(\"conv2\")\n if hasattr(self, \"nm\"):\n self.__delattr__(\"nm\")\n if hasattr(self, \"bn\"):\n self.__delattr__(\"bn\")\n if hasattr(self, \"id_tensor\"):\n self.__delattr__(\"id_tensor\")", "chunk_type": "class", "name": "RepConv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 375, "end_line": 538, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "RepConv module with training and deploy modes.\n\nThis module is used in RT-DETR and can fuse convolutions during inference for efficiency.\n\nAttributes:\n conv1 (Conv): 3x3 convolution.\n conv2 (Conv): 1x1 convolution.\n bn (nn.BatchNorm2d, optional): Batch normalization for identity branch.\n act (nn.Module): Activation function.\n default_act (nn.Module): Default activation function (SiLU).\n\nReferences:\n https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_RepConv_d1e80245" }, { "content": "class ChannelAttention(nn.Module):\n \"\"\"\n Channel-attention module for feature recalibration.\n\n Applies attention weights to channels based on global average pooling.\n\n Attributes:\n pool (nn.AdaptiveAvgPool2d): Global average pooling.\n fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.\n act (nn.Sigmoid): Sigmoid activation for attention weights.\n\n References:\n https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet\n \"\"\"\n\n def __init__(self, channels: int) -> None:\n \"\"\"\n Initialize Channel-attention 
module.\n\n Args:\n channels (int): Number of input channels.\n \"\"\"\n super().__init__()\n self.pool = nn.AdaptiveAvgPool2d(1)\n self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)\n self.act = nn.Sigmoid()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply channel attention to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Channel-attended output tensor.\n \"\"\"\n return x * self.act(self.fc(self.pool(x)))", "chunk_type": "class", "name": "ChannelAttention", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 541, "end_line": 578, "start_col": 0, "end_col": 50, "parent_name": null, "docstring": "Channel-attention module for feature recalibration.\n\nApplies attention weights to channels based on global average pooling.\n\nAttributes:\n pool (nn.AdaptiveAvgPool2d): Global average pooling.\n fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.\n act (nn.Sigmoid): Sigmoid activation for attention weights.\n\nReferences:\n https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_ChannelAttention_0d43909d" }, { "content": "class SpatialAttention(nn.Module):\n \"\"\"\n Spatial-attention module for feature recalibration.\n\n Applies attention weights to spatial dimensions based on channel statistics.\n\n Attributes:\n cv1 (nn.Conv2d): Convolution layer for spatial attention.\n act (nn.Sigmoid): Sigmoid activation for attention weights.\n \"\"\"\n\n def __init__(self, kernel_size=7):\n \"\"\"\n Initialize Spatial-attention module.\n\n Args:\n kernel_size (int): Size of the convolutional kernel (3 or 7).\n \"\"\"\n super().__init__()\n assert kernel_size in {3, 7}, \"kernel size must be 3 or 7\"\n padding = 3 if kernel_size == 7 else 1\n self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)\n self.act = nn.Sigmoid()\n\n def forward(self, x):\n \"\"\"\n Apply spatial attention to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Spatial-attended output tensor.\n \"\"\"\n return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))", "chunk_type": "class", "name": "SpatialAttention", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 581, "end_line": 615, "start_col": 0, "end_col": 119, "parent_name": null, "docstring": "Spatial-attention module for feature recalibration.\n\nApplies attention weights to spatial dimensions based on channel statistics.\n\nAttributes:\n cv1 (nn.Conv2d): Convolution layer for spatial attention.\n act (nn.Sigmoid): Sigmoid activation for attention weights.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_SpatialAttention_57c216fb" }, { "content": "class CBAM(nn.Module):\n \"\"\"\n Convolutional Block Attention Module.\n\n Combines channel and spatial attention mechanisms for comprehensive feature refinement.\n\n Attributes:\n channel_attention (ChannelAttention): Channel attention module.\n spatial_attention (SpatialAttention): Spatial attention module.\n \"\"\"\n\n def __init__(self, c1, kernel_size=7):\n \"\"\"\n Initialize CBAM with given parameters.\n\n Args:\n c1 (int): Number 
of input channels.\n kernel_size (int): Size of the convolutional kernel for spatial attention.\n \"\"\"\n super().__init__()\n self.channel_attention = ChannelAttention(c1)\n self.spatial_attention = SpatialAttention(kernel_size)\n\n def forward(self, x):\n \"\"\"\n Apply channel and spatial attention sequentially to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Attended output tensor.\n \"\"\"\n return self.spatial_attention(self.channel_attention(x))", "chunk_type": "class", "name": "CBAM", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 618, "end_line": 651, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": "Convolutional Block Attention Module.\n\nCombines channel and spatial attention mechanisms for comprehensive feature refinement.\n\nAttributes:\n channel_attention (ChannelAttention): Channel attention module.\n spatial_attention (SpatialAttention): Spatial attention module.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_CBAM_174201a4" }, { "content": "class Concat(nn.Module):\n \"\"\"\n Concatenate a list of tensors along specified dimension.\n\n Attributes:\n d (int): Dimension along which to concatenate tensors.\n \"\"\"\n\n def __init__(self, dimension=1):\n \"\"\"\n Initialize Concat module.\n\n Args:\n dimension (int): Dimension along which to concatenate tensors.\n \"\"\"\n super().__init__()\n self.d = dimension\n\n def forward(self, x: List[torch.Tensor]):\n \"\"\"\n Concatenate input tensors along specified dimension.\n\n Args:\n x (List[torch.Tensor]): List of input tensors.\n\n Returns:\n (torch.Tensor): Concatenated tensor.\n \"\"\"\n return torch.cat(x, self.d)", "chunk_type": "class", "name": "Concat", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 654, "end_line": 682, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": "Concatenate a list of tensors along specified dimension.\n\nAttributes:\n d (int): Dimension along which to concatenate tensors.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_Concat_89f2ccf2" }, { "content": "class Index(nn.Module):\n \"\"\"\n Returns a particular index of the input.\n\n Attributes:\n index (int): Index to select from input.\n \"\"\"\n\n def __init__(self, index=0):\n \"\"\"\n Initialize Index module.\n\n Args:\n index (int): Index to select from input.\n \"\"\"\n super().__init__()\n self.index = index\n\n def forward(self, x: List[torch.Tensor]):\n \"\"\"\n Select and return a particular index from input.\n\n Args:\n x (List[torch.Tensor]): List of input tensors.\n\n Returns:\n (torch.Tensor): Selected tensor.\n \"\"\"\n return x[self.index]", "chunk_type": "class", "name": "Index", "file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py", "start_line": 685, "end_line": 713, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Returns a particular index of the input.\n\nAttributes:\n index (int): Index to select from input.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "numpy", "torch", "torch.nn", "nn.Module" ], "chunk_id": "class_Index_72f4bfca" }, { "content": "import copy", "chunk_type": "import", "name": "copy", 
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_db5e079d" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_3d4ed46a" }, { "content": "from typing import List, Optional, Tuple, Union", "chunk_type": "import", "name": "List, Optional, Tuple, Union", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional, Tuple, Union_7be0b002" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_2a9f0f53" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_51b0a839" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_a2c8732e" }, { "content": "from torch.nn.init import constant_, xavier_uniform_", "chunk_type": "import", "name": "constant_, xavier_uniform_", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_constant_, xavier_uniform__dd452c3d" }, { "content": "from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors", "chunk_type": "import", "name": "TORCH_1_10, dist2bbox, dist2rbox, make_anchors", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_TORCH_1_10, dist2bbox, dist2rbox, make_anchors_8e915633" }, { "content": "from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode", "chunk_type": "import", "name": "fuse_conv_and_bn, smart_inference_mode", 
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 14, "end_line": 14, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_fuse_conv_and_bn, smart_inference_mode_5e4099c9" }, { "content": "from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN", "chunk_type": "import", "name": "DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 16, "end_line": 16, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN_7e717db0" }, { "content": "from .conv import Conv, DWConv", "chunk_type": "import", "name": "Conv, DWConv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 17, "end_line": 17, "start_col": 0, "end_col": 30, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Conv, DWConv_163d563d" }, { "content": "from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer", "chunk_type": "import", "name": "MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 18, "end_line": 18, "start_col": 0, "end_col": 93, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer_e8a853e2" }, { "content": "from .utils import bias_init_with_prob, linear_init", "chunk_type": "import", "name": "bias_init_with_prob, linear_init", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 19, "end_line": 19, "start_col": 0, "end_col": 51, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_bias_init_with_prob, linear_init_f82d045a" }, { "content": "__all__ = \"Detect\", \"Segment\", \"Pose\", \"Classify\", \"OBB\", \"RTDETRDecoder\", \"v10Detect\", \"YOLOEDetect\", \"YOLOESegment\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 21, "end_line": 21, "start_col": 0, "end_col": 117, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___04011789" }, { "content": "class Detect(nn.Module):\n \"\"\"\n YOLO Detect head for object detection models.\n\n This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.\n It supports both training and inference modes, with optional end-to-end detection capabilities.\n\n Attributes:\n dynamic (bool): Force grid reconstruction.\n export (bool): Export mode flag.\n format (str): Export format.\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum detections per image.\n shape (tuple): Input shape.\n anchors (torch.Tensor): 
Anchor points.\n strides (torch.Tensor): Feature map strides.\n legacy (bool): Backward compatibility for v3/v5/v8/v9 models.\n xyxy (bool): Output format, xyxy or xywh.\n nc (int): Number of classes.\n nl (int): Number of detection layers.\n reg_max (int): DFL channels.\n no (int): Number of outputs per anchor.\n stride (torch.Tensor): Strides computed during build.\n cv2 (nn.ModuleList): Convolution layers for box regression.\n cv3 (nn.ModuleList): Convolution layers for classification.\n dfl (nn.Module): Distribution Focal Loss layer.\n one2one_cv2 (nn.ModuleList): One-to-one convolution layers for box regression.\n one2one_cv3 (nn.ModuleList): One-to-one convolution layers for classification.\n\n Methods:\n forward: Perform forward pass and return predictions.\n forward_end2end: Perform forward pass for end-to-end detection.\n bias_init: Initialize detection head biases.\n decode_bboxes: Decode bounding boxes from predictions.\n postprocess: Post-process model predictions.\n\n Examples:\n Create a detection head for 80 classes\n >>> detect = Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = detect(x)\n \"\"\"\n\n dynamic = False # force grid reconstruction\n export = False # export mode\n format = None # export format\n end2end = False # end2end\n max_det = 300 # max_det\n shape = None\n anchors = torch.empty(0) # init\n strides = torch.empty(0) # init\n legacy = False # backward compatibility for v3/v5/v8/v9 models\n xyxy = False # xyxy or xywh output\n\n def __init__(self, nc: int = 80, ch: Tuple = ()):\n \"\"\"\n Initialize the YOLO detection layer with specified number of classes and channels.\n\n Args:\n nc (int): Number of classes.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__()\n self.nc = nc # number of classes\n self.nl = len(ch) # number of detection layers\n self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)\n self.no = nc + self.reg_max * 4 # number of outputs per anchor\n self.stride = torch.zeros(self.nl) # strides computed during build\n c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels\n self.cv2 = nn.ModuleList(\n nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch\n )\n self.cv3 = (\n nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)\n if self.legacy\n else nn.ModuleList(\n nn.Sequential(\n nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),\n nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),\n nn.Conv2d(c3, self.nc, 1),\n )\n for x in ch\n )\n )\n self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()\n\n if self.end2end:\n self.one2one_cv2 = copy.deepcopy(self.cv2)\n self.one2one_cv3 = copy.deepcopy(self.cv3)\n\n def forward(self, x: List[torch.Tensor]) -> Union[List[torch.Tensor], Tuple]:\n \"\"\"Concatenate and return predicted bounding boxes and class probabilities.\"\"\"\n if self.end2end:\n return self.forward_end2end(x)\n\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)\n if self.training: # Training path\n return x\n y = self._inference(x)\n return y if self.export else (y, x)\n\n def forward_end2end(self, x: List[torch.Tensor]) -> Union[dict, Tuple]:\n \"\"\"\n Perform forward pass of the v10Detect module.\n\n Args:\n x (List[torch.Tensor]): Input feature maps from different levels.\n\n Returns:\n 
outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.\n Inference mode returns processed detections or tuple with detections and raw outputs.\n \"\"\"\n x_detach = [xi.detach() for xi in x]\n one2one = [\n torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)\n ]\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)\n if self.training: # Training path\n return {\"one2many\": x, \"one2one\": one2one}\n\n y = self._inference(one2one)\n y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)\n return y if self.export else (y, {\"one2many\": x, \"one2one\": one2one})\n\n def _inference(self, x: List[torch.Tensor]) -> torch.Tensor:\n \"\"\"\n Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.\n\n Args:\n x (List[torch.Tensor]): List of feature maps from different detection layers.\n\n Returns:\n (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.\n \"\"\"\n # Inference path\n shape = x[0].shape # BCHW\n x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)\n if self.format != \"imx\" and (self.dynamic or self.shape != shape):\n self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))\n self.shape = shape\n\n if self.export and self.format in {\"saved_model\", \"pb\", \"tflite\", \"edgetpu\", \"tfjs\"}: # avoid TF FlexSplitV ops\n box = x_cat[:, : self.reg_max * 4]\n cls = x_cat[:, self.reg_max * 4 :]\n else:\n box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)\n\n if self.export and self.format in {\"tflite\", \"edgetpu\"}:\n # Precompute normalization factor to increase numerical stability\n # See https://github.com/ultralytics/ultralytics/issues/7371\n grid_h = shape[2]\n grid_w = shape[3]\n grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)\n norm = self.strides / (self.stride[0] * grid_size)\n dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])\n else:\n dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides\n if self.export and self.format == \"imx\":\n return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)\n return torch.cat((dbox, cls.sigmoid()), 1)\n\n def bias_init(self):\n \"\"\"Initialize Detect() biases, WARNING: requires stride availability.\"\"\"\n m = self # self.model[-1] # Detect() module\n # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1\n # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency\n for a, b, s in zip(m.cv2, m.cv3, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box\n b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)\n if self.end2end:\n for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box\n b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)\n\n def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:\n \"\"\"Decode bounding boxes from predictions.\"\"\"\n return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)\n\n @staticmethod\n def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:\n \"\"\"\n Post-process YOLO model predictions.\n\n 
Args:\n preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension\n format [x, y, w, h, class_probs].\n max_det (int): Maximum detections per image.\n nc (int, optional): Number of classes.\n\n Returns:\n (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last\n dimension format [x, y, w, h, max_class_prob, class_index].\n \"\"\"\n batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)\n boxes, scores = preds.split([4, nc], dim=-1)\n index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)\n boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))\n scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))\n scores, index = scores.flatten(1).topk(min(max_det, anchors))\n i = torch.arange(batch_size)[..., None] # batch indices\n return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)", "chunk_type": "class", "name": "Detect", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 24, "end_line": 226, "start_col": 0, "end_col": 109, "parent_name": null, "docstring": "YOLO Detect head for object detection models.\n\nThis class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.\nIt supports both training and inference modes, with optional end-to-end detection capabilities.\n\nAttributes:\n dynamic (bool): Force grid reconstruction.\n export (bool): Export mode flag.\n format (str): Export format.\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum detections per image.\n shape (tuple): Input shape.\n anchors (torch.Tensor): Anchor points.\n strides (torch.Tensor): Feature map strides.\n legacy (bool): Backward compatibility for v3/v5/v8/v9 models.\n xyxy (bool): Output format, xyxy or xywh.\n nc (int): Number of classes.\n nl (int): Number of detection layers.\n reg_max (int): DFL channels.\n no (int): Number of outputs per anchor.\n stride (torch.Tensor): Strides computed during build.\n cv2 (nn.ModuleList): Convolution layers for box regression.\n cv3 (nn.ModuleList): Convolution layers for classification.\n dfl (nn.Module): Distribution Focal Loss layer.\n one2one_cv2 (nn.ModuleList): One-to-one convolution layers for box regression.\n one2one_cv3 (nn.ModuleList): One-to-one convolution layers for classification.\n\nMethods:\n forward: Perform forward pass and return predictions.\n forward_end2end: Perform forward pass for end-to-end detection.\n bias_init: Initialize detection head biases.\n decode_bboxes: Decode bounding boxes from predictions.\n postprocess: Post-process model predictions.\n\nExamples:\n Create a detection head for 80 classes\n >>> detect = Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = detect(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", 
"block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "nn.Module" ], "chunk_id": "class_Detect_72afa236" }, { "content": "class Segment(Detect):\n \"\"\"\n YOLO Segment head for segmentation models.\n\n This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.\n\n Attributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv4 (nn.ModuleList): Convolution layers for mask coefficients.\n\n Methods:\n forward: Return model outputs and mask coefficients.\n\n Examples:\n Create a segmentation head\n >>> segment = Segment(nc=80, nm=32, npr=256, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = segment(x)\n \"\"\"\n\n def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: Tuple = ()):\n \"\"\"\n Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.\n\n Args:\n nc (int): Number of classes.\n nm (int): Number of masks.\n npr (int): Number of protos.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n self.nm = nm # number of masks\n self.npr = npr # number of protos\n self.proto = Proto(ch[0], self.npr, self.nm) # protos\n\n c4 = max(ch[0] // 4, self.nm)\n self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor]) -> Union[Tuple, List[torch.Tensor]]:\n \"\"\"Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.\"\"\"\n p = self.proto(x[0]) # mask protos\n bs = p.shape[0] # batch size\n\n mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients\n x = Detect.forward(self, x)\n if self.training:\n return x, mc, p\n return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))", "chunk_type": "class", "name": "Segment", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 229, "end_line": 278, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "YOLO Segment head for segmentation models.\n\nThis class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.\n\nAttributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv4 (nn.ModuleList): Convolution layers for mask coefficients.\n\nMethods:\n forward: Return model outputs and mask coefficients.\n\nExamples:\n Create a segmentation head\n >>> segment = Segment(nc=80, nm=32, npr=256, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = segment(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", 
"ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "Detect" ], "chunk_id": "class_Segment_477a5588" }, { "content": "class OBB(Detect):\n \"\"\"\n YOLO OBB detection head for detection with rotation models.\n\n This class extends the Detect head to include oriented bounding box prediction with rotation angles.\n\n Attributes:\n ne (int): Number of extra parameters.\n cv4 (nn.ModuleList): Convolution layers for angle prediction.\n angle (torch.Tensor): Predicted rotation angles.\n\n Methods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n decode_bboxes: Decode rotated bounding boxes.\n\n Examples:\n Create an OBB detection head\n >>> obb = OBB(nc=80, ne=1, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = obb(x)\n \"\"\"\n\n def __init__(self, nc: int = 80, ne: int = 1, ch: Tuple = ()):\n \"\"\"\n Initialize OBB with number of classes `nc` and layer channels `ch`.\n\n Args:\n nc (int): Number of classes.\n ne (int): Number of extra parameters.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n self.ne = ne # number of extra parameters\n\n c4 = max(ch[0] // 4, self.ne)\n self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor]) -> Union[torch.Tensor, Tuple]:\n \"\"\"Concatenate and return predicted bounding boxes and class probabilities.\"\"\"\n bs = x[0].shape[0] # batch size\n angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits\n # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.\n angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]\n # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]\n if not self.training:\n self.angle = angle\n x = Detect.forward(self, x)\n if self.training:\n return x, angle\n return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))\n\n def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode rotated bounding boxes.\"\"\"\n return dist2rbox(bboxes, self.angle, anchors, dim=1)", "chunk_type": "class", "name": "OBB", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 281, "end_line": 334, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": "YOLO OBB detection head for detection with rotation models.\n\nThis class extends the Detect head to include oriented bounding box prediction with rotation angles.\n\nAttributes:\n ne (int): Number of extra parameters.\n cv4 (nn.ModuleList): Convolution layers for angle prediction.\n angle (torch.Tensor): Predicted rotation angles.\n\nMethods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n decode_bboxes: Decode rotated bounding boxes.\n\nExamples:\n Create an OBB detection head\n >>> obb = OBB(nc=80, ne=1, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), 
torch.randn(1, 1024, 20, 20)]\n >>> outputs = obb(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "Detect" ], "chunk_id": "class_OBB_85f3b3c3" }, { "content": "class Pose(Detect):\n \"\"\"\n YOLO Pose head for keypoints models.\n\n This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.\n\n Attributes:\n kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).\n nk (int): Total number of keypoint values.\n cv4 (nn.ModuleList): Convolution layers for keypoint prediction.\n\n Methods:\n forward: Perform forward pass through YOLO model and return predictions.\n kpts_decode: Decode keypoints from predictions.\n\n Examples:\n Create a pose detection head\n >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = pose(x)\n \"\"\"\n\n def __init__(self, nc: int = 80, kpt_shape: Tuple = (17, 3), ch: Tuple = ()):\n \"\"\"\n Initialize YOLO network with default parameters and Convolutional Layers.\n\n Args:\n nc (int): Number of classes.\n kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)\n self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total\n\n c4 = max(ch[0] // 4, self.nk)\n self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor]) -> Union[torch.Tensor, Tuple]:\n \"\"\"Perform forward pass through YOLO model and return predictions.\"\"\"\n bs = x[0].shape[0] # batch size\n kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)\n x = Detect.forward(self, x)\n if self.training:\n return x, kpt\n pred_kpt = self.kpts_decode(bs, kpt)\n if self.export and self.format == \"imx\":\n return (*x, pred_kpt.permute(0, 2, 1))\n return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))\n\n def kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode keypoints from predictions.\"\"\"\n ndim = self.kpt_shape[1]\n if self.export:\n if self.format in {\n \"tflite\",\n \"edgetpu\",\n }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug\n # Precompute normalization factor to increase numerical stability\n y = kpts.view(bs, *self.kpt_shape, -1)\n grid_h, grid_w = self.shape[2], 
self.shape[3]\n grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)\n norm = self.strides / (self.stride[0] * grid_size)\n a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm\n else:\n # NCNN fix\n y = kpts.view(bs, *self.kpt_shape, -1)\n a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides\n if ndim == 3:\n a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)\n return a.view(bs, self.nk, -1)\n else:\n y = kpts.clone()\n if ndim == 3:\n y[:, 2::ndim] = y[:, 2::ndim].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)\n y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides\n y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides\n return y", "chunk_type": "class", "name": "Pose", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 337, "end_line": 414, "start_col": 0, "end_col": 20, "parent_name": null, "docstring": "YOLO Pose head for keypoints models.\n\nThis class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.\n\nAttributes:\n kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).\n nk (int): Total number of keypoint values.\n cv4 (nn.ModuleList): Convolution layers for keypoint prediction.\n\nMethods:\n forward: Perform forward pass through YOLO model and return predictions.\n kpts_decode: Decode keypoints from predictions.\n\nExamples:\n Create a pose detection head\n >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = pose(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "Detect" ], "chunk_id": "class_Pose_852b4e1a" }, { "content": "class Classify(nn.Module):\n \"\"\"\n YOLO classification head, i.e. 
x(b,c1,20,20) to x(b,c2).\n\n This class implements a classification head that transforms feature maps into class predictions.\n\n Attributes:\n export (bool): Export mode flag.\n conv (Conv): Convolutional layer for feature transformation.\n pool (nn.AdaptiveAvgPool2d): Global average pooling layer.\n drop (nn.Dropout): Dropout layer for regularization.\n linear (nn.Linear): Linear layer for final classification.\n\n Methods:\n forward: Perform forward pass of the YOLO model on input image data.\n\n Examples:\n Create a classification head\n >>> classify = Classify(c1=1024, c2=1000)\n >>> x = torch.randn(1, 1024, 20, 20)\n >>> output = classify(x)\n \"\"\"\n\n export = False # export mode\n\n def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: Optional[int] = None, g: int = 1):\n \"\"\"\n Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output classes.\n k (int, optional): Kernel size.\n s (int, optional): Stride.\n p (int, optional): Padding.\n g (int, optional): Groups.\n \"\"\"\n super().__init__()\n c_ = 1280 # efficientnet_b0 size\n self.conv = Conv(c1, c_, k, s, p, g)\n self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)\n self.drop = nn.Dropout(p=0.0, inplace=True)\n self.linear = nn.Linear(c_, c2) # to x(b,c2)\n\n def forward(self, x: Union[List[torch.Tensor], torch.Tensor]) -> Union[torch.Tensor, Tuple]:\n \"\"\"Perform forward pass of the YOLO model on input image data.\"\"\"\n if isinstance(x, list):\n x = torch.cat(x, 1)\n x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))\n if self.training:\n return x\n y = x.softmax(1) # get final output\n return y if self.export else (y, x)", "chunk_type": "class", "name": "Classify", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 417, "end_line": 469, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "YOLO classification head, i.e. 
x(b,c1,20,20) to x(b,c2).\n\nThis class implements a classification head that transforms feature maps into class predictions.\n\nAttributes:\n export (bool): Export mode flag.\n conv (Conv): Convolutional layer for feature transformation.\n pool (nn.AdaptiveAvgPool2d): Global average pooling layer.\n drop (nn.Dropout): Dropout layer for regularization.\n linear (nn.Linear): Linear layer for final classification.\n\nMethods:\n forward: Perform forward pass of the YOLO model on input image data.\n\nExamples:\n Create a classification head\n >>> classify = Classify(c1=1024, c2=1000)\n >>> x = torch.randn(1, 1024, 20, 20)\n >>> output = classify(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "nn.Module" ], "chunk_id": "class_Classify_2697ab2b" }, { "content": "class WorldDetect(Detect):\n \"\"\"\n Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\n This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding\n in object detection tasks.\n\n Attributes:\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n\n Methods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n bias_init: Initialize detection head biases.\n\n Examples:\n Create a WorldDetect head\n >>> world_detect = WorldDetect(nc=80, embed=512, with_bn=False, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = world_detect(x, text)\n \"\"\"\n\n def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: Tuple = ()):\n \"\"\"\n Initialize YOLO detection layer with nc classes and layer channels ch.\n\n Args:\n nc (int): Number of classes.\n embed (int): Embedding dimension.\n with_bn (bool): Whether to use batch normalization in contrastive head.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n c3 = max(ch[0], min(self.nc, 100))\n self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)\n self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)\n\n def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[List[torch.Tensor], Tuple]:\n \"\"\"Concatenate and return predicted bounding boxes and class probabilities.\"\"\"\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)\n if self.training:\n return x\n self.no = self.nc + 
self.reg_max * 4 # self.nc could be changed when inference with different texts\n y = self._inference(x)\n return y if self.export else (y, x)\n\n def bias_init(self):\n \"\"\"Initialize Detect() biases, WARNING: requires stride availability.\"\"\"\n m = self # self.model[-1] # Detect() module\n # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1\n # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency\n for a, b, s in zip(m.cv2, m.cv3, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box", "chunk_type": "class", "name": "WorldDetect", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 472, "end_line": 526, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": "Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\nThis class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding\nin object detection tasks.\n\nAttributes:\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n\nMethods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n bias_init: Initialize detection head biases.\n\nExamples:\n Create a WorldDetect head\n >>> world_detect = WorldDetect(nc=80, embed=512, with_bn=False, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = world_detect(x, text)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "Detect" ], "chunk_id": "class_WorldDetect_4c307729" }, { "content": "class LRPCHead(nn.Module):\n \"\"\"\n Lightweight Region Proposal and Classification Head for efficient object detection.\n\n This head combines region proposal filtering with classification to enable efficient detection with\n dynamic vocabulary support.\n\n Attributes:\n vocab (nn.Module): Vocabulary/classification layer.\n pf (nn.Module): Proposal filter module.\n loc (nn.Module): Localization module.\n enabled (bool): Whether the head is enabled.\n\n Methods:\n conv2linear: Convert a 1x1 convolutional layer to a linear layer.\n forward: Process classification and localization features to generate detection proposals.\n\n Examples:\n Create an LRPC head\n >>> vocab = nn.Conv2d(256, 80, 1)\n >>> pf = nn.Conv2d(256, 1, 1)\n >>> loc = nn.Conv2d(256, 4, 1)\n >>> head = LRPCHead(vocab, pf, loc, enabled=True)\n \"\"\"\n\n def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):\n 
\"\"\"\n Initialize LRPCHead with vocabulary, proposal filter, and localization components.\n\n Args:\n vocab (nn.Module): Vocabulary/classification module.\n pf (nn.Module): Proposal filter module.\n loc (nn.Module): Localization module.\n enabled (bool): Whether to enable the head functionality.\n \"\"\"\n super().__init__()\n self.vocab = self.conv2linear(vocab) if enabled else vocab\n self.pf = pf\n self.loc = loc\n self.enabled = enabled\n\n def conv2linear(self, conv: nn.Conv2d) -> nn.Linear:\n \"\"\"Convert a 1x1 convolutional layer to a linear layer.\"\"\"\n assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)\n linear = nn.Linear(conv.in_channels, conv.out_channels)\n linear.weight.data = conv.weight.view(conv.out_channels, -1).data\n linear.bias.data = conv.bias.data\n return linear\n\n def forward(self, cls_feat: torch.Tensor, loc_feat: torch.Tensor, conf: float) -> Tuple[Tuple, torch.Tensor]:\n \"\"\"Process classification and localization features to generate detection proposals.\"\"\"\n if self.enabled:\n pf_score = self.pf(cls_feat)[0, 0].flatten(0)\n mask = pf_score.sigmoid() > conf\n cls_feat = cls_feat.flatten(2).transpose(-1, -2)\n cls_feat = self.vocab(cls_feat[:, mask] if conf else cls_feat * mask.unsqueeze(-1).int())\n return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask\n else:\n cls_feat = self.vocab(cls_feat)\n loc_feat = self.loc(loc_feat)\n return (loc_feat, cls_feat.flatten(2)), torch.ones(\n cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool\n )", "chunk_type": "class", "name": "LRPCHead", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 530, "end_line": 592, "start_col": 0, "end_col": 13, "parent_name": null, "docstring": "Lightweight Region Proposal and Classification Head for efficient object detection.\n\nThis head combines region proposal filtering with classification to enable efficient detection with\ndynamic vocabulary support.\n\nAttributes:\n vocab (nn.Module): Vocabulary/classification layer.\n pf (nn.Module): Proposal filter module.\n loc (nn.Module): Localization module.\n enabled (bool): Whether the head is enabled.\n\nMethods:\n conv2linear: Convert a 1x1 convolutional layer to a linear layer.\n forward: Process classification and localization features to generate detection proposals.\n\nExamples:\n Create an LRPC head\n >>> vocab = nn.Conv2d(256, 80, 1)\n >>> pf = nn.Conv2d(256, 1, 1)\n >>> loc = nn.Conv2d(256, 4, 1)\n >>> head = LRPCHead(vocab, pf, loc, enabled=True)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "nn.Module" ], "chunk_id": "class_LRPCHead_4219ae44" }, { "content": "class YOLOEDetect(Detect):\n \"\"\"\n Head 
for integrating YOLO detection models with semantic understanding from text embeddings.\n\n This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding\n through text embeddings and visual prompt embeddings.\n\n Attributes:\n is_fused (bool): Whether the model is fused for inference.\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n reprta (Residual): Residual block for text prompt embeddings.\n savpe (SAVPE): Spatial-aware visual prompt embeddings module.\n embed (int): Embedding dimension.\n\n Methods:\n fuse: Fuse text features with model weights for efficient inference.\n get_tpe: Get text prompt embeddings with normalization.\n get_vpe: Get visual prompt embeddings with spatial awareness.\n forward_lrpc: Process features with fused text embeddings for prompt-free model.\n forward: Process features with class prompt embeddings to generate detections.\n bias_init: Initialize biases for detection heads.\n\n Examples:\n Create a YOLOEDetect head\n >>> yoloe_detect = YOLOEDetect(nc=80, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> cls_pe = torch.randn(1, 80, 512)\n >>> outputs = yoloe_detect(x, cls_pe)\n \"\"\"\n\n is_fused = False\n\n def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: Tuple = ()):\n \"\"\"\n Initialize YOLO detection layer with nc classes and layer channels ch.\n\n Args:\n nc (int): Number of classes.\n embed (int): Embedding dimension.\n with_bn (bool): Whether to use batch normalization in contrastive head.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n c3 = max(ch[0], min(self.nc, 100))\n assert c3 <= embed\n assert with_bn is True\n self.cv3 = (\n nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)\n if self.legacy\n else nn.ModuleList(\n nn.Sequential(\n nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),\n nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),\n nn.Conv2d(c3, embed, 1),\n )\n for x in ch\n )\n )\n\n self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)\n\n self.reprta = Residual(SwiGLUFFN(embed, embed))\n self.savpe = SAVPE(ch, c3, embed)\n self.embed = embed\n\n @smart_inference_mode()\n def fuse(self, txt_feats: torch.Tensor):\n \"\"\"Fuse text features with model weights for efficient inference.\"\"\"\n if self.is_fused:\n return\n\n assert not self.training\n txt_feats = txt_feats.to(torch.float32).squeeze(0)\n for cls_head, bn_head in zip(self.cv3, self.cv4):\n assert isinstance(cls_head, nn.Sequential)\n assert isinstance(bn_head, BNContrastiveHead)\n conv = cls_head[-1]\n assert isinstance(conv, nn.Conv2d)\n logit_scale = bn_head.logit_scale\n bias = bn_head.bias\n norm = bn_head.norm\n\n t = txt_feats * logit_scale.exp()\n conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)\n\n w = conv.weight.data.squeeze(-1).squeeze(-1)\n b = conv.bias.data\n\n w = t @ w\n b1 = (t @ b.reshape(-1).unsqueeze(-1)).squeeze(-1)\n b2 = torch.ones_like(b1) * bias\n\n conv = (\n nn.Conv2d(\n conv.in_channels,\n w.shape[0],\n kernel_size=1,\n )\n .requires_grad_(False)\n .to(conv.weight.device)\n )\n\n conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))\n conv.bias.data.copy_(b1 + b2)\n cls_head[-1] = conv\n\n bn_head.fuse()\n\n del self.reprta\n self.reprta = 
nn.Identity()\n self.is_fused = True\n\n def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n \"\"\"Get text prompt embeddings with normalization.\"\"\"\n return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)\n\n def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:\n \"\"\"Get visual prompt embeddings with spatial awareness.\"\"\"\n if vpe.shape[1] == 0: # no visual prompt embeddings\n return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)\n if vpe.ndim == 4: # (B, N, H, W)\n vpe = self.savpe(x, vpe)\n assert vpe.ndim == 3 # (B, N, D)\n return vpe\n\n def forward_lrpc(self, x: List[torch.Tensor], return_mask: bool = False) -> Union[torch.Tensor, Tuple]:\n \"\"\"Process features with fused text embeddings to generate detections for prompt-free model.\"\"\"\n masks = []\n assert self.is_fused, \"Prompt-free inference requires model to be fused!\"\n for i in range(self.nl):\n cls_feat = self.cv3[i](x[i])\n loc_feat = self.cv2[i](x[i])\n assert isinstance(self.lrpc[i], LRPCHead)\n x[i], mask = self.lrpc[i](\n cls_feat, loc_feat, 0 if self.export and not self.dynamic else getattr(self, \"conf\", 0.001)\n )\n masks.append(mask)\n shape = x[0][0].shape\n if self.dynamic or self.shape != shape:\n self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))\n self.shape = shape\n box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)\n cls = torch.cat([xi[1] for xi in x], 2)\n\n if self.export and self.format in {\"tflite\", \"edgetpu\"}:\n # Precompute normalization factor to increase numerical stability\n # See https://github.com/ultralytics/ultralytics/issues/7371\n grid_h = shape[2]\n grid_w = shape[3]\n grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)\n norm = self.strides / (self.stride[0] * grid_size)\n dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])\n else:\n dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides\n\n mask = torch.cat(masks)\n y = torch.cat((dbox if self.export and not self.dynamic else dbox[..., mask], cls.sigmoid()), 1)\n\n if return_mask:\n return (y, mask) if self.export else ((y, x), mask)\n else:\n return y if self.export else (y, x)\n\n def forward(\n self, x: List[torch.Tensor], cls_pe: torch.Tensor, return_mask: bool = False\n ) -> Union[torch.Tensor, Tuple]:\n \"\"\"Process features with class prompt embeddings to generate detections.\"\"\"\n if hasattr(self, \"lrpc\"): # for prompt-free inference\n return self.forward_lrpc(x, return_mask)\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)\n if self.training:\n return x\n self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts\n y = self._inference(x)\n return y if self.export else (y, x)\n\n def bias_init(self):\n \"\"\"Initialize biases for detection heads.\"\"\"\n m = self # self.model[-1] # Detect() module\n # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1\n # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency\n for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box\n # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)\n b[-1].bias.data[:] = 0.0\n 
c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)", "chunk_type": "class", "name": "YOLOEDetect", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 595, "end_line": 782, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": "Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\nThis class extends the standard Detect head to support text-guided detection with enhanced semantic understanding\nthrough text embeddings and visual prompt embeddings.\n\nAttributes:\n is_fused (bool): Whether the model is fused for inference.\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n reprta (Residual): Residual block for text prompt embeddings.\n savpe (SAVPE): Spatial-aware visual prompt embeddings module.\n embed (int): Embedding dimension.\n\nMethods:\n fuse: Fuse text features with model weights for efficient inference.\n get_tpe: Get text prompt embeddings with normalization.\n get_vpe: Get visual prompt embeddings with spatial awareness.\n forward_lrpc: Process features with fused text embeddings for prompt-free model.\n forward: Process features with class prompt embeddings to generate detections.\n bias_init: Initialize biases for detection heads.\n\nExamples:\n Create a YOLOEDetect head\n >>> yoloe_detect = YOLOEDetect(nc=80, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> cls_pe = torch.randn(1, 80, 512)\n >>> outputs = yoloe_detect(x, cls_pe)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "Detect" ], "chunk_id": "class_YOLOEDetect_9945901b" }, { "content": "class YOLOESegment(YOLOEDetect):\n \"\"\"\n YOLO segmentation head with text embedding capabilities.\n\n This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks\n with text-guided semantic understanding.\n\n Attributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv5 (nn.ModuleList): Convolution layers for mask coefficients.\n\n Methods:\n forward: Return model outputs and mask coefficients.\n\n Examples:\n Create a YOLOESegment head\n >>> yoloe_segment = YOLOESegment(nc=80, nm=32, npr=256, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = yoloe_segment(x, text)\n \"\"\"\n\n def __init__(\n self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = 
False, ch: Tuple = ()\n ):\n \"\"\"\n Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.\n\n Args:\n nc (int): Number of classes.\n nm (int): Number of masks.\n npr (int): Number of protos.\n embed (int): Embedding dimension.\n with_bn (bool): Whether to use batch normalization in contrastive head.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, embed, with_bn, ch)\n self.nm = nm\n self.npr = npr\n self.proto = Proto(ch[0], self.npr, self.nm)\n\n c5 = max(ch[0] // 4, self.nm)\n self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]:\n \"\"\"Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.\"\"\"\n p = self.proto(x[0]) # mask protos\n bs = p.shape[0] # batch size\n\n mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients\n has_lrpc = hasattr(self, \"lrpc\")\n\n if not has_lrpc:\n x = YOLOEDetect.forward(self, x, text)\n else:\n x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)\n\n if self.training:\n return x, mc, p\n\n if has_lrpc:\n mc = (mc * mask.int()) if self.export and not self.dynamic else mc[..., mask]\n\n return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))", "chunk_type": "class", "name": "YOLOESegment", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 785, "end_line": 850, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "YOLO segmentation head with text embedding capabilities.\n\nThis class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks\nwith text-guided semantic understanding.\n\nAttributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv5 (nn.ModuleList): Convolution layers for mask coefficients.\n\nMethods:\n forward: Return model outputs and mask coefficients.\n\nExamples:\n Create a YOLOESegment head\n >>> yoloe_segment = YOLOESegment(nc=80, nm=32, npr=256, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = yoloe_segment(x, text)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "YOLOEDetect" ], "chunk_id": "class_YOLOESegment_8116a2f2" }, { "content": "class RTDETRDecoder(nn.Module):\n \"\"\"\n Real-Time Deformable Transformer Decoder (RTDETRDecoder) 
module for object detection.\n\n This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes\n and class labels for objects in an image. It integrates features from multiple layers and runs through a series of\n Transformer decoder layers to output the final predictions.\n\n Attributes:\n export (bool): Export mode flag.\n hidden_dim (int): Dimension of hidden layers.\n nhead (int): Number of heads in multi-head attention.\n nl (int): Number of feature levels.\n nc (int): Number of classes.\n num_queries (int): Number of query points.\n num_decoder_layers (int): Number of decoder layers.\n input_proj (nn.ModuleList): Input projection layers for backbone features.\n decoder (DeformableTransformerDecoder): Transformer decoder module.\n denoising_class_embed (nn.Embedding): Class embeddings for denoising.\n num_denoising (int): Number of denoising queries.\n label_noise_ratio (float): Label noise ratio for training.\n box_noise_scale (float): Box noise scale for training.\n learnt_init_query (bool): Whether to learn initial query embeddings.\n tgt_embed (nn.Embedding): Target embeddings for queries.\n query_pos_head (MLP): Query position head.\n enc_output (nn.Sequential): Encoder output layers.\n enc_score_head (nn.Linear): Encoder score prediction head.\n enc_bbox_head (MLP): Encoder bbox prediction head.\n dec_score_head (nn.ModuleList): Decoder score prediction heads.\n dec_bbox_head (nn.ModuleList): Decoder bbox prediction heads.\n\n Methods:\n forward: Run forward pass and return bounding box and classification scores.\n\n Examples:\n Create an RTDETRDecoder\n >>> decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300)\n >>> x = [torch.randn(1, 512, 64, 64), torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]\n >>> outputs = decoder(x)\n \"\"\"\n\n export = False # export mode\n\n def __init__(\n self,\n nc: int = 80,\n ch: Tuple = (512, 1024, 2048),\n hd: int = 256, # hidden dim\n nq: int = 300, # num queries\n ndp: int = 4, # num decoder points\n nh: int = 8, # num head\n ndl: int = 6, # num decoder layers\n d_ffn: int = 1024, # dim of feedforward\n dropout: float = 0.0,\n act: nn.Module = nn.ReLU(),\n eval_idx: int = -1,\n # Training args\n nd: int = 100, # num denoising\n label_noise_ratio: float = 0.5,\n box_noise_scale: float = 1.0,\n learnt_init_query: bool = False,\n ):\n \"\"\"\n Initialize the RTDETRDecoder module with the given parameters.\n\n Args:\n nc (int): Number of classes.\n ch (tuple): Channels in the backbone feature maps.\n hd (int): Dimension of hidden layers.\n nq (int): Number of query points.\n ndp (int): Number of decoder points.\n nh (int): Number of heads in multi-head attention.\n ndl (int): Number of decoder layers.\n d_ffn (int): Dimension of the feed-forward networks.\n dropout (float): Dropout rate.\n act (nn.Module): Activation function.\n eval_idx (int): Evaluation index.\n nd (int): Number of denoising.\n label_noise_ratio (float): Label noise ratio.\n box_noise_scale (float): Box noise scale.\n learnt_init_query (bool): Whether to learn initial query embeddings.\n \"\"\"\n super().__init__()\n self.hidden_dim = hd\n self.nhead = nh\n self.nl = len(ch) # num level\n self.nc = nc\n self.num_queries = nq\n self.num_decoder_layers = ndl\n\n # Backbone feature projection\n self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)\n # NOTE: simplified version but it's not consistent with .pt weights.\n # self.input_proj = 
nn.ModuleList(Conv(x, hd, act=False) for x in ch)\n\n # Transformer module\n decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)\n self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)\n\n # Denoising part\n self.denoising_class_embed = nn.Embedding(nc, hd)\n self.num_denoising = nd\n self.label_noise_ratio = label_noise_ratio\n self.box_noise_scale = box_noise_scale\n\n # Decoder embedding\n self.learnt_init_query = learnt_init_query\n if learnt_init_query:\n self.tgt_embed = nn.Embedding(nq, hd)\n self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)\n\n # Encoder head\n self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))\n self.enc_score_head = nn.Linear(hd, nc)\n self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)\n\n # Decoder head\n self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])\n self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])\n\n self._reset_parameters()\n\n def forward(self, x: List[torch.Tensor], batch: Optional[dict] = None) -> Union[Tuple, torch.Tensor]:\n \"\"\"\n Run the forward pass of the module, returning bounding box and classification scores for the input.\n\n Args:\n x (List[torch.Tensor]): List of feature maps from the backbone.\n batch (dict, optional): Batch information for training.\n\n Returns:\n outputs (tuple | torch.Tensor): During training, returns a tuple of bounding boxes, scores, and other\n metadata. During inference, returns a tensor of shape (bs, 300, 4+nc) containing bounding boxes and\n class scores.\n \"\"\"\n from ultralytics.models.utils.ops import get_cdn_group\n\n # Input projection and embedding\n feats, shapes = self._get_encoder_input(x)\n\n # Prepare denoising training\n dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(\n batch,\n self.nc,\n self.num_queries,\n self.denoising_class_embed.weight,\n self.num_denoising,\n self.label_noise_ratio,\n self.box_noise_scale,\n self.training,\n )\n\n embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)\n\n # Decoder\n dec_bboxes, dec_scores = self.decoder(\n embed,\n refer_bbox,\n feats,\n shapes,\n self.dec_bbox_head,\n self.dec_score_head,\n self.query_pos_head,\n attn_mask=attn_mask,\n )\n x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta\n if self.training:\n return x\n # (bs, 300, 4+nc)\n y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)\n return y if self.export else (y, x)\n\n def _generate_anchors(\n self,\n shapes: List[List[int]],\n grid_size: float = 0.05,\n dtype: torch.dtype = torch.float32,\n device: str = \"cpu\",\n eps: float = 1e-2,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Generate anchor bounding boxes for given shapes with specific grid size and validate them.\n\n Args:\n shapes (list): List of feature map shapes.\n grid_size (float, optional): Base size of grid cells.\n dtype (torch.dtype, optional): Data type for tensors.\n device (str, optional): Device to create tensors on.\n eps (float, optional): Small value for numerical stability.\n\n Returns:\n anchors (torch.Tensor): Generated anchor boxes.\n valid_mask (torch.Tensor): Valid mask for anchors.\n \"\"\"\n anchors = []\n for i, (h, w) in enumerate(shapes):\n sy = torch.arange(end=h, dtype=dtype, device=device)\n sx = torch.arange(end=w, dtype=dtype, device=device)\n grid_y, grid_x = torch.meshgrid(sy, sx, indexing=\"ij\") if TORCH_1_10 else torch.meshgrid(sy, sx)\n grid_xy = 
torch.stack([grid_x, grid_y], -1) # (h, w, 2)\n\n valid_WH = torch.tensor([w, h], dtype=dtype, device=device)\n grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)\n wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)\n anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)\n\n anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)\n valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1\n anchors = torch.log(anchors / (1 - anchors))\n anchors = anchors.masked_fill(~valid_mask, float(\"inf\"))\n return anchors, valid_mask\n\n def _get_encoder_input(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor, List[List[int]]]:\n \"\"\"\n Process and return encoder inputs by getting projection features from input and concatenating them.\n\n Args:\n x (List[torch.Tensor]): List of feature maps from the backbone.\n\n Returns:\n feats (torch.Tensor): Processed features.\n shapes (list): List of feature map shapes.\n \"\"\"\n # Get projection features\n x = [self.input_proj[i](feat) for i, feat in enumerate(x)]\n # Get encoder inputs\n feats = []\n shapes = []\n for feat in x:\n h, w = feat.shape[2:]\n # [b, c, h, w] -> [b, h*w, c]\n feats.append(feat.flatten(2).permute(0, 2, 1))\n # [nl, 2]\n shapes.append([h, w])\n\n # [b, h*w, c]\n feats = torch.cat(feats, 1)\n return feats, shapes\n\n def _get_decoder_input(\n self,\n feats: torch.Tensor,\n shapes: List[List[int]],\n dn_embed: Optional[torch.Tensor] = None,\n dn_bbox: Optional[torch.Tensor] = None,\n ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Generate and prepare the input required for the decoder from the provided features and shapes.\n\n Args:\n feats (torch.Tensor): Processed features from encoder.\n shapes (list): List of feature map shapes.\n dn_embed (torch.Tensor, optional): Denoising embeddings.\n dn_bbox (torch.Tensor, optional): Denoising bounding boxes.\n\n Returns:\n embeddings (torch.Tensor): Query embeddings for decoder.\n refer_bbox (torch.Tensor): Reference bounding boxes.\n enc_bboxes (torch.Tensor): Encoded bounding boxes.\n enc_scores (torch.Tensor): Encoded scores.\n \"\"\"\n bs = feats.shape[0]\n # Prepare input for decoder\n anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)\n features = self.enc_output(valid_mask * feats) # bs, h*w, 256\n\n enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)\n\n # Query selection\n # (bs, num_queries)\n topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)\n # (bs, num_queries)\n batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)\n\n # (bs, num_queries, 256)\n top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)\n # (bs, num_queries, 4)\n top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)\n\n # Dynamic anchors + static content\n refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors\n\n enc_bboxes = refer_bbox.sigmoid()\n if dn_bbox is not None:\n refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)\n enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)\n\n embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features\n if self.training:\n refer_bbox = refer_bbox.detach()\n if not self.learnt_init_query:\n embeddings = embeddings.detach()\n if dn_embed is not None:\n embeddings = 
torch.cat([dn_embed, embeddings], 1)\n\n return embeddings, refer_bbox, enc_bboxes, enc_scores\n\n def _reset_parameters(self):\n \"\"\"Initialize or reset the parameters of the model's various components with predefined weights and biases.\"\"\"\n # Class and bbox head init\n bias_cls = bias_init_with_prob(0.01) / 80 * self.nc\n # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.\n # linear_init(self.enc_score_head)\n constant_(self.enc_score_head.bias, bias_cls)\n constant_(self.enc_bbox_head.layers[-1].weight, 0.0)\n constant_(self.enc_bbox_head.layers[-1].bias, 0.0)\n for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):\n # linear_init(cls_)\n constant_(cls_.bias, bias_cls)\n constant_(reg_.layers[-1].weight, 0.0)\n constant_(reg_.layers[-1].bias, 0.0)\n\n linear_init(self.enc_output[0])\n xavier_uniform_(self.enc_output[0].weight)\n if self.learnt_init_query:\n xavier_uniform_(self.tgt_embed.weight)\n xavier_uniform_(self.query_pos_head.layers[0].weight)\n xavier_uniform_(self.query_pos_head.layers[1].weight)\n for layer in self.input_proj:\n xavier_uniform_(layer[0].weight)", "chunk_type": "class", "name": "RTDETRDecoder", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 853, "end_line": 1172, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.\n\nThis decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes\nand class labels for objects in an image. It integrates features from multiple layers and runs through a series of\nTransformer decoder layers to output the final predictions.\n\nAttributes:\n export (bool): Export mode flag.\n hidden_dim (int): Dimension of hidden layers.\n nhead (int): Number of heads in multi-head attention.\n nl (int): Number of feature levels.\n nc (int): Number of classes.\n num_queries (int): Number of query points.\n num_decoder_layers (int): Number of decoder layers.\n input_proj (nn.ModuleList): Input projection layers for backbone features.\n decoder (DeformableTransformerDecoder): Transformer decoder module.\n denoising_class_embed (nn.Embedding): Class embeddings for denoising.\n num_denoising (int): Number of denoising queries.\n label_noise_ratio (float): Label noise ratio for training.\n box_noise_scale (float): Box noise scale for training.\n learnt_init_query (bool): Whether to learn initial query embeddings.\n tgt_embed (nn.Embedding): Target embeddings for queries.\n query_pos_head (MLP): Query position head.\n enc_output (nn.Sequential): Encoder output layers.\n enc_score_head (nn.Linear): Encoder score prediction head.\n enc_bbox_head (MLP): Encoder bbox prediction head.\n dec_score_head (nn.ModuleList): Decoder score prediction heads.\n dec_bbox_head (nn.ModuleList): Decoder bbox prediction heads.\n\nMethods:\n forward: Run forward pass and return bounding box and classification scores.\n\nExamples:\n Create an RTDETRDecoder\n >>> decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300)\n >>> x = [torch.randn(1, 512, 64, 64), torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]\n >>> outputs = decoder(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", 
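A minimal eval-mode sketch of exercising RTDETRDecoder on its own, assuming the import path `ultralytics.nn.modules.head` and illustrative feature-map shapes matching the default `ch=(512, 1024, 2048)`; in inference the decoder concatenates sigmoid class scores onto the refined boxes.

```python
import torch
from ultralytics.nn.modules.head import RTDETRDecoder  # assumed import path

decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300).eval()
feats = [
    torch.randn(1, 512, 64, 64),   # illustrative backbone/neck outputs, one per level in `ch`
    torch.randn(1, 1024, 32, 32),
    torch.randn(1, 2048, 16, 16),
]
with torch.no_grad():
    y, aux = decoder(feats)        # eval mode: (bs, nq, 4 + nc) plus the raw encoder/decoder outputs
print(y.shape)                     # torch.Size([1, 300, 84]) -> normalized boxes followed by 80 class scores
```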
"torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "nn.Module" ], "chunk_id": "class_RTDETRDecoder_7ce5df20" }, { "content": "class v10Detect(Detect):\n \"\"\"\n v10 Detection head from https://arxiv.org/pdf/2405.14458.\n\n This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions\n for improved efficiency and performance.\n\n Attributes:\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum number of detections.\n cv3 (nn.ModuleList): Light classification head layers.\n one2one_cv3 (nn.ModuleList): One-to-one classification head layers.\n\n Methods:\n __init__: Initialize the v10Detect object with specified number of classes and input channels.\n forward: Perform forward pass of the v10Detect module.\n bias_init: Initialize biases of the Detect module.\n fuse: Remove the one2many head for inference optimization.\n\n Examples:\n Create a v10Detect head\n >>> v10_detect = v10Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = v10_detect(x)\n \"\"\"\n\n end2end = True\n\n def __init__(self, nc: int = 80, ch: Tuple = ()):\n \"\"\"\n Initialize the v10Detect object with the specified number of classes and input channels.\n\n Args:\n nc (int): Number of classes.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n c3 = max(ch[0], min(self.nc, 100)) # channels\n # Light cls head\n self.cv3 = nn.ModuleList(\n nn.Sequential(\n nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),\n nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),\n nn.Conv2d(c3, self.nc, 1),\n )\n for x in ch\n )\n self.one2one_cv3 = copy.deepcopy(self.cv3)\n\n def fuse(self):\n \"\"\"Remove the one2many head for inference optimization.\"\"\"\n self.cv2 = self.cv3 = nn.ModuleList([nn.Identity()] * self.nl)", "chunk_type": "class", "name": "v10Detect", "file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py", "start_line": 1175, "end_line": 1226, "start_col": 0, "end_col": 70, "parent_name": null, "docstring": "v10 Detection head from https://arxiv.org/pdf/2405.14458.\n\nThis class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions\nfor improved efficiency and performance.\n\nAttributes:\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum number of detections.\n cv3 (nn.ModuleList): Light classification head layers.\n one2one_cv3 (nn.ModuleList): One-to-one classification head layers.\n\nMethods:\n __init__: Initialize the v10Detect object with specified number of classes and input channels.\n forward: Perform forward pass of the v10Detect module.\n bias_init: Initialize biases of the Detect module.\n fuse: Remove the one2many head for inference optimization.\n\nExamples:\n Create a v10Detect head\n >>> v10_detect = v10Detect(nc=80, ch=(256, 512, 1024))\n >>> 
x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = v10_detect(x)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "math", "typing.List", "typing.Optional", "typing.Tuple", "typing.Union", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "ultralytics.utils.tal.TORCH_1_10", "ultralytics.utils.tal.dist2bbox", "ultralytics.utils.tal.dist2rbox", "ultralytics.utils.tal.make_anchors", "ultralytics.utils.torch_utils.fuse_conv_and_bn", "ultralytics.utils.torch_utils.smart_inference_mode", "block.DFL", "block.SAVPE", "block.BNContrastiveHead", "block.ContrastiveHead", "block.Proto", "block.Residual", "block.SwiGLUFFN", "conv.Conv", "conv.DWConv", "transformer.MLP", "transformer.DeformableTransformerDecoder", "transformer.DeformableTransformerDecoderLayer", "utils.bias_init_with_prob", "utils.linear_init", "ultralytics.models.utils.ops.get_cdn_group", "Detect" ], "chunk_id": "class_v10Detect_6bfc5509" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_b656c83b" }, { "content": "from typing import List, Optional", "chunk_type": "import", "name": "List, Optional", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional_bceca4d3" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch_2c3e4f94" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_15d9547d" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_6458c49b" }, { "content": "from torch.nn.init import constant_, xavier_uniform_", "chunk_type": "import", "name": "constant_, xavier_uniform_", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, 
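As a small illustration of the dual-branch design, the sketch below builds the head in isolation (assuming the `ultralytics.nn.modules.head` import path and illustrative per-level channels), runs it in training mode so no stride/anchor metadata is required, then calls `fuse()` to drop the one-to-many branch for deployment.

```python
import torch
from ultralytics.nn.modules.head import v10Detect  # assumed import path

head = v10Detect(nc=80, ch=(256, 512, 1024))        # illustrative per-level channels
x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
out = head(x)   # training mode: raw one-to-many and one-to-one branch outputs (structure set by the Detect base)
head.fuse()     # one-to-many box/cls branches become nn.Identity; only the one-to-one path remains for inference
```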
"chunk_id": "import_constant_, xavier_uniform__cf516011" }, { "content": "from .conv import Conv", "chunk_type": "import", "name": "Conv", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Conv_28de563c" }, { "content": "from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch", "chunk_type": "import", "name": "_get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 13, "end_line": 13, "start_col": 0, "end_col": 84, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import__get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch_60657b4f" }, { "content": "__all__ = (\n \"TransformerEncoderLayer\",\n \"TransformerLayer\",\n \"TransformerBlock\",\n \"MLPBlock\",\n \"LayerNorm2d\",\n \"AIFI\",\n \"DeformableTransformerDecoder\",\n \"DeformableTransformerDecoderLayer\",\n \"MSDeformAttn\",\n \"MLP\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 15, "end_line": 26, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___515bdabe" }, { "content": "class TransformerEncoderLayer(nn.Module):\n \"\"\"\n A single layer of the transformer encoder.\n\n This class implements a standard transformer encoder layer with multi-head attention and feedforward network,\n supporting both pre-normalization and post-normalization configurations.\n\n Attributes:\n ma (nn.MultiheadAttention): Multi-head attention module.\n fc1 (nn.Linear): First linear layer in the feedforward network.\n fc2 (nn.Linear): Second linear layer in the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization after attention.\n norm2 (nn.LayerNorm): Layer normalization after feedforward network.\n dropout (nn.Dropout): Dropout layer for the feedforward network.\n dropout1 (nn.Dropout): Dropout layer after attention.\n dropout2 (nn.Dropout): Dropout layer after feedforward network.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n cm: int = 2048,\n num_heads: int = 8,\n dropout: float = 0.0,\n act: nn.Module = nn.GELU(),\n normalize_before: bool = False,\n ):\n \"\"\"\n Initialize the TransformerEncoderLayer with specified parameters.\n\n Args:\n c1 (int): Input dimension.\n cm (int): Hidden dimension in the feedforward network.\n num_heads (int): Number of attention heads.\n dropout (float): Dropout probability.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.\n \"\"\"\n super().__init__()\n from ...utils.torch_utils import TORCH_1_9\n\n if not TORCH_1_9:\n raise ModuleNotFoundError(\n \"TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).\"\n )\n self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)\n # Implementation of 
Feedforward model\n self.fc1 = nn.Linear(c1, cm)\n self.fc2 = nn.Linear(cm, c1)\n\n self.norm1 = nn.LayerNorm(c1)\n self.norm2 = nn.LayerNorm(c1)\n self.dropout = nn.Dropout(dropout)\n self.dropout1 = nn.Dropout(dropout)\n self.dropout2 = nn.Dropout(dropout)\n\n self.act = act\n self.normalize_before = normalize_before\n\n @staticmethod\n def with_pos_embed(tensor: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:\n \"\"\"Add position embeddings to the tensor if provided.\"\"\"\n return tensor if pos is None else tensor + pos\n\n def forward_post(\n self,\n src: torch.Tensor,\n src_mask: Optional[torch.Tensor] = None,\n src_key_padding_mask: Optional[torch.Tensor] = None,\n pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform forward pass with post-normalization.\n\n Args:\n src (torch.Tensor): Input tensor.\n src_mask (torch.Tensor, optional): Mask for the src sequence.\n src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.\n pos (torch.Tensor, optional): Positional encoding.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feedforward.\n \"\"\"\n q = k = self.with_pos_embed(src, pos)\n src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]\n src = src + self.dropout1(src2)\n src = self.norm1(src)\n src2 = self.fc2(self.dropout(self.act(self.fc1(src))))\n src = src + self.dropout2(src2)\n return self.norm2(src)\n\n def forward_pre(\n self,\n src: torch.Tensor,\n src_mask: Optional[torch.Tensor] = None,\n src_key_padding_mask: Optional[torch.Tensor] = None,\n pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform forward pass with pre-normalization.\n\n Args:\n src (torch.Tensor): Input tensor.\n src_mask (torch.Tensor, optional): Mask for the src sequence.\n src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.\n pos (torch.Tensor, optional): Positional encoding.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feedforward.\n \"\"\"\n src2 = self.norm1(src)\n q = k = self.with_pos_embed(src2, pos)\n src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]\n src = src + self.dropout1(src2)\n src2 = self.norm2(src)\n src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))\n return src + self.dropout2(src2)\n\n def forward(\n self,\n src: torch.Tensor,\n src_mask: Optional[torch.Tensor] = None,\n src_key_padding_mask: Optional[torch.Tensor] = None,\n pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Forward propagate the input through the encoder module.\n\n Args:\n src (torch.Tensor): Input tensor.\n src_mask (torch.Tensor, optional): Mask for the src sequence.\n src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.\n pos (torch.Tensor, optional): Positional encoding.\n\n Returns:\n (torch.Tensor): Output tensor after transformer encoder layer.\n \"\"\"\n if self.normalize_before:\n return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n return self.forward_post(src, src_mask, src_key_padding_mask, pos)", "chunk_type": "class", "name": "TransformerEncoderLayer", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 29, "end_line": 170, "start_col": 0, "end_col": 74, "parent_name": null, "docstring": "A single layer of the transformer encoder.\n\nThis class implements a standard transformer encoder layer with multi-head attention and feedforward network,\nsupporting both pre-normalization 
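A short sketch contrasting the post-norm (default) and pre-norm paths, assuming the import path `ultralytics.nn.modules.transformer`; the layer expects batch-first `(batch, tokens, channels)` input because it uses `nn.MultiheadAttention(batch_first=True)`.

```python
import torch
from ultralytics.nn.modules.transformer import TransformerEncoderLayer  # assumed import path

src = torch.randn(2, 100, 256)                                           # (batch, tokens, channels)
post_norm = TransformerEncoderLayer(c1=256, cm=2048, num_heads=8)
pre_norm = TransformerEncoderLayer(c1=256, cm=2048, num_heads=8, normalize_before=True)
print(post_norm(src).shape, pre_norm(src).shape)                         # both preserve torch.Size([2, 100, 256])
```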
and post-normalization configurations.\n\nAttributes:\n ma (nn.MultiheadAttention): Multi-head attention module.\n fc1 (nn.Linear): First linear layer in the feedforward network.\n fc2 (nn.Linear): Second linear layer in the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization after attention.\n norm2 (nn.LayerNorm): Layer normalization after feedforward network.\n dropout (nn.Dropout): Dropout layer for the feedforward network.\n dropout1 (nn.Dropout): Dropout layer after attention.\n dropout2 (nn.Dropout): Dropout layer after feedforward network.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_TransformerEncoderLayer_836709d4" }, { "content": "class AIFI(TransformerEncoderLayer):\n \"\"\"\n AIFI transformer layer for 2D data with positional embeddings.\n\n This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional\n embeddings and handling the spatial dimensions appropriately.\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n cm: int = 2048,\n num_heads: int = 8,\n dropout: float = 0,\n act: nn.Module = nn.GELU(),\n normalize_before: bool = False,\n ):\n \"\"\"\n Initialize the AIFI instance with specified parameters.\n\n Args:\n c1 (int): Input dimension.\n cm (int): Hidden dimension in the feedforward network.\n num_heads (int): Number of attention heads.\n dropout (float): Dropout probability.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.\n \"\"\"\n super().__init__(c1, cm, num_heads, dropout, act, normalize_before)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass for the AIFI transformer layer.\n\n Args:\n x (torch.Tensor): Input tensor with shape [B, C, H, W].\n\n Returns:\n (torch.Tensor): Output tensor with shape [B, C, H, W].\n \"\"\"\n c, h, w = x.shape[1:]\n pos_embed = self.build_2d_sincos_position_embedding(w, h, c)\n # Flatten [B, C, H, W] to [B, HxW, C]\n x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))\n return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()\n\n @staticmethod\n def build_2d_sincos_position_embedding(\n w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0\n ) -> torch.Tensor:\n \"\"\"\n Build 2D sine-cosine position embedding.\n\n Args:\n w (int): Width of the feature map.\n h (int): Height of the feature map.\n embed_dim (int): Embedding dimension.\n temperature (float): Temperature for the sine/cosine functions.\n\n Returns:\n (torch.Tensor): Position embedding with shape [1, embed_dim, h*w].\n \"\"\"\n assert embed_dim % 4 == 0, \"Embed dimension must be divisible by 4 for 2D sin-cos position embedding\"\n grid_w = torch.arange(w, dtype=torch.float32)\n grid_h = torch.arange(h, dtype=torch.float32)\n grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing=\"ij\")\n pos_dim = embed_dim // 4\n omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n omega = 1.0 / (temperature**omega)\n\n out_w = 
grid_w.flatten()[..., None] @ omega[None]\n out_h = grid_h.flatten()[..., None] @ omega[None]\n\n return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]", "chunk_type": "class", "name": "AIFI", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 173, "end_line": 246, "start_col": 0, "end_col": 107, "parent_name": null, "docstring": "AIFI transformer layer for 2D data with positional embeddings.\n\nThis class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional\nembeddings and handling the spatial dimensions appropriately.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "TransformerEncoderLayer" ], "chunk_id": "class_AIFI_8f6b96f3" }, { "content": "class TransformerLayer(nn.Module):\n \"\"\"Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).\"\"\"\n\n def __init__(self, c: int, num_heads: int):\n \"\"\"\n Initialize a self-attention mechanism using linear transformations and multi-head attention.\n\n Args:\n c (int): Input and output channel dimension.\n num_heads (int): Number of attention heads.\n \"\"\"\n super().__init__()\n self.q = nn.Linear(c, c, bias=False)\n self.k = nn.Linear(c, c, bias=False)\n self.v = nn.Linear(c, c, bias=False)\n self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)\n self.fc1 = nn.Linear(c, c, bias=False)\n self.fc2 = nn.Linear(c, c, bias=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply a transformer block to the input x and return the output.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after transformer layer.\n \"\"\"\n x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x\n return self.fc2(self.fc1(x)) + x", "chunk_type": "class", "name": "TransformerLayer", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 249, "end_line": 279, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_TransformerLayer_f04f01ea" }, { "content": "class TransformerBlock(nn.Module):\n \"\"\"\n Vision Transformer block based on https://arxiv.org/abs/2010.11929.\n\n This class implements a complete transformer block with optional convolution layer for channel adjustment,\n learnable position embedding, and multiple transformer layers.\n\n Attributes:\n conv (Conv, optional): Convolution layer if input and output channels differ.\n linear (nn.Linear): Learnable position embedding.\n tr (nn.Sequential): Sequential container of transformer layers.\n c2 (int): Output channel dimension.\n \"\"\"\n\n def __init__(self, c1: int, c2: int, num_heads: 
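A small sketch of AIFI on a single intra-scale feature map, assuming the `ultralytics.nn.modules.transformer` import path; it also calls the static position-embedding builder directly, which yields one `embed_dim`-length embedding per spatial location.

```python
import torch
from ultralytics.nn.modules.transformer import AIFI  # assumed import path

x = torch.randn(1, 256, 20, 20)                       # e.g. the deepest (P5-like) feature map
aifi = AIFI(c1=256, cm=1024, num_heads=8)
print(aifi(x).shape)                                  # torch.Size([1, 256, 20, 20]) -- spatial layout restored

pe = AIFI.build_2d_sincos_position_embedding(20, 20, embed_dim=256)
print(pe.shape)                                       # torch.Size([1, 400, 256]) -- one 256-d embedding per cell
```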
int, num_layers: int):\n \"\"\"\n Initialize a Transformer module with position embedding and specified number of heads and layers.\n\n Args:\n c1 (int): Input channel dimension.\n c2 (int): Output channel dimension.\n num_heads (int): Number of attention heads.\n num_layers (int): Number of transformer layers.\n \"\"\"\n super().__init__()\n self.conv = None\n if c1 != c2:\n self.conv = Conv(c1, c2)\n self.linear = nn.Linear(c2, c2) # learnable position embedding\n self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))\n self.c2 = c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward propagate the input through the transformer block.\n\n Args:\n x (torch.Tensor): Input tensor with shape [b, c1, w, h].\n\n Returns:\n (torch.Tensor): Output tensor with shape [b, c2, w, h].\n \"\"\"\n if self.conv is not None:\n x = self.conv(x)\n b, _, w, h = x.shape\n p = x.flatten(2).permute(2, 0, 1)\n return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)", "chunk_type": "class", "name": "TransformerBlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 282, "end_line": 328, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": "Vision Transformer block based on https://arxiv.org/abs/2010.11929.\n\nThis class implements a complete transformer block with optional convolution layer for channel adjustment,\nlearnable position embedding, and multiple transformer layers.\n\nAttributes:\n conv (Conv, optional): Convolution layer if input and output channels differ.\n linear (nn.Linear): Learnable position embedding.\n tr (nn.Sequential): Sequential container of transformer layers.\n c2 (int): Output channel dimension.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_TransformerBlock_def174e0" }, { "content": "class MLPBlock(nn.Module):\n \"\"\"A single block of a multi-layer perceptron.\"\"\"\n\n def __init__(self, embedding_dim: int, mlp_dim: int, act=nn.GELU):\n \"\"\"\n Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.\n\n Args:\n embedding_dim (int): Input and output dimension.\n mlp_dim (int): Hidden dimension.\n act (nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.lin1 = nn.Linear(embedding_dim, mlp_dim)\n self.lin2 = nn.Linear(mlp_dim, embedding_dim)\n self.act = act()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass for the MLPBlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after MLP block.\n \"\"\"\n return self.lin2(self.act(self.lin1(x)))", "chunk_type": "class", "name": "MLPBlock", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 331, "end_line": 358, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": "A single block of a multi-layer perceptron.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", 
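A quick sketch of TransformerBlock and MLPBlock under illustrative sizes and the assumed `ultralytics.nn.modules.transformer` import path; the block projects channels with a Conv when `c1 != c2` before flattening to tokens, and MLPBlock is the plain Linear-activation-Linear unit.

```python
import torch
from ultralytics.nn.modules.transformer import MLPBlock, TransformerBlock  # assumed import path

block = TransformerBlock(c1=128, c2=256, num_heads=8, num_layers=2)  # c1 != c2, so a Conv projects 128 -> 256 first
x = torch.randn(1, 128, 20, 20)
print(block(x).shape)                                                # torch.Size([1, 256, 20, 20])

ffn = MLPBlock(embedding_dim=256, mlp_dim=1024)                      # Linear(256->1024) -> GELU -> Linear(1024->256)
tokens = torch.randn(1, 400, 256)
print(ffn(tokens).shape)                                             # torch.Size([1, 400, 256])
```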
"utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_MLPBlock_cfb56b63" }, { "content": "class MLP(nn.Module):\n \"\"\"\n A simple multi-layer perceptron (also called FFN).\n\n This class implements a configurable MLP with multiple linear layers, activation functions, and optional\n sigmoid output activation.\n\n Attributes:\n num_layers (int): Number of layers in the MLP.\n layers (nn.ModuleList): List of linear layers.\n sigmoid (bool): Whether to apply sigmoid to the output.\n act (nn.Module): Activation function.\n \"\"\"\n\n def __init__(\n self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False\n ):\n \"\"\"\n Initialize the MLP with specified input, hidden, output dimensions and number of layers.\n\n Args:\n input_dim (int): Input dimension.\n hidden_dim (int): Hidden dimension.\n output_dim (int): Output dimension.\n num_layers (int): Number of layers.\n act (nn.Module): Activation function.\n sigmoid (bool): Whether to apply sigmoid to the output.\n \"\"\"\n super().__init__()\n self.num_layers = num_layers\n h = [hidden_dim] * (num_layers - 1)\n self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n self.sigmoid = sigmoid\n self.act = act()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass for the entire MLP.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after MLP.\n \"\"\"\n for i, layer in enumerate(self.layers):\n x = getattr(self, \"act\", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)\n return x.sigmoid() if getattr(self, \"sigmoid\", False) else x", "chunk_type": "class", "name": "MLP", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 361, "end_line": 408, "start_col": 0, "end_col": 68, "parent_name": null, "docstring": "A simple multi-layer perceptron (also called FFN).\n\nThis class implements a configurable MLP with multiple linear layers, activation functions, and optional\nsigmoid output activation.\n\nAttributes:\n num_layers (int): Number of layers in the MLP.\n layers (nn.ModuleList): List of linear layers.\n sigmoid (bool): Whether to apply sigmoid to the output.\n act (nn.Module): Activation function.", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_MLP_7a99ec10" }, { "content": "class LayerNorm2d(nn.Module):\n \"\"\"\n 2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.\n\n This class implements layer normalization for 2D feature maps, normalizing across the channel dimension\n while preserving spatial dimensions.\n\n Attributes:\n weight (nn.Parameter): Learnable scale parameter.\n bias (nn.Parameter): Learnable bias parameter.\n eps (float): Small constant for numerical stability.\n\n References:\n https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py\n https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py\n \"\"\"\n\n def __init__(self, num_channels: int, eps: float = 1e-6):\n \"\"\"\n Initialize 
LayerNorm2d with the given parameters.\n\n Args:\n num_channels (int): Number of channels in the input.\n eps (float): Small constant for numerical stability.\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(num_channels))\n self.bias = nn.Parameter(torch.zeros(num_channels))\n self.eps = eps\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform forward pass for 2D layer normalization.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Normalized output tensor.\n \"\"\"\n u = x.mean(1, keepdim=True)\n s = (x - u).pow(2).mean(1, keepdim=True)\n x = (x - u) / torch.sqrt(s + self.eps)\n return self.weight[:, None, None] * x + self.bias[:, None, None]", "chunk_type": "class", "name": "LayerNorm2d", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 411, "end_line": 454, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": "2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.\n\nThis class implements layer normalization for 2D feature maps, normalizing across the channel dimension\nwhile preserving spatial dimensions.\n\nAttributes:\n weight (nn.Parameter): Learnable scale parameter.\n bias (nn.Parameter): Learnable bias parameter.\n eps (float): Small constant for numerical stability.\n\nReferences:\n https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py\n https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_LayerNorm2d_933f0d90" }, { "content": "class MSDeformAttn(nn.Module):\n \"\"\"\n Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.\n\n This module implements multiscale deformable attention that can attend to features at multiple scales\n with learnable sampling locations and attention weights.\n\n Attributes:\n im2col_step (int): Step size for im2col operations.\n d_model (int): Model dimension.\n n_levels (int): Number of feature levels.\n n_heads (int): Number of attention heads.\n n_points (int): Number of sampling points per attention head per feature level.\n sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.\n attention_weights (nn.Linear): Linear layer for generating attention weights.\n value_proj (nn.Linear): Linear layer for projecting values.\n output_proj (nn.Linear): Linear layer for projecting output.\n\n References:\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py\n \"\"\"\n\n def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):\n \"\"\"\n Initialize MSDeformAttn with the given parameters.\n\n Args:\n d_model (int): Model dimension.\n n_levels (int): Number of feature levels.\n n_heads (int): Number of attention heads.\n n_points (int): Number of sampling points per attention head per feature level.\n \"\"\"\n super().__init__()\n if d_model % n_heads != 0:\n raise ValueError(f\"d_model must be divisible by n_heads, but got {d_model} and {n_heads}\")\n _d_per_head = d_model // n_heads\n # Better to 
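A tiny sketch of LayerNorm2d on a channels-first map (assumed import path as above); it normalizes each spatial position across channels, so with the default unit scale and zero bias the per-pixel channel mean is driven to roughly zero.

```python
import torch
from ultralytics.nn.modules.transformer import LayerNorm2d  # assumed import path

ln = LayerNorm2d(num_channels=64)
x = torch.randn(2, 64, 32, 32)
y = ln(x)                                          # normalized over the channel dimension at every (h, w) location
print(y.shape, float(y.mean(dim=1).abs().max()))   # shape preserved; per-pixel channel mean ~0 at init
```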
set _d_per_head to a power of 2 which is more efficient in a CUDA implementation\n assert _d_per_head * n_heads == d_model, \"`d_model` must be divisible by `n_heads`\"\n\n self.im2col_step = 64\n\n self.d_model = d_model\n self.n_levels = n_levels\n self.n_heads = n_heads\n self.n_points = n_points\n\n self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)\n self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)\n self.value_proj = nn.Linear(d_model, d_model)\n self.output_proj = nn.Linear(d_model, d_model)\n\n self._reset_parameters()\n\n def _reset_parameters(self):\n \"\"\"Reset module parameters.\"\"\"\n constant_(self.sampling_offsets.weight.data, 0.0)\n thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)\n grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)\n grid_init = (\n (grid_init / grid_init.abs().max(-1, keepdim=True)[0])\n .view(self.n_heads, 1, 1, 2)\n .repeat(1, self.n_levels, self.n_points, 1)\n )\n for i in range(self.n_points):\n grid_init[:, :, i, :] *= i + 1\n with torch.no_grad():\n self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))\n constant_(self.attention_weights.weight.data, 0.0)\n constant_(self.attention_weights.bias.data, 0.0)\n xavier_uniform_(self.value_proj.weight.data)\n constant_(self.value_proj.bias.data, 0.0)\n xavier_uniform_(self.output_proj.weight.data)\n constant_(self.output_proj.bias.data, 0.0)\n\n def forward(\n self,\n query: torch.Tensor,\n refer_bbox: torch.Tensor,\n value: torch.Tensor,\n value_shapes: List,\n value_mask: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform forward pass for multiscale deformable attention.\n\n Args:\n query (torch.Tensor): Query tensor with shape [bs, query_length, C].\n refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2],\n range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area.\n value (torch.Tensor): Value tensor with shape [bs, value_length, C].\n value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].\n value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for non-padding\n elements, False for padding elements.\n\n Returns:\n (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].\n\n References:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n \"\"\"\n bs, len_q = query.shape[:2]\n len_v = value.shape[1]\n assert sum(s[0] * s[1] for s in value_shapes) == len_v\n\n value = self.value_proj(value)\n if value_mask is not None:\n value = value.masked_fill(value_mask[..., None], float(0))\n value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)\n sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)\n attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)\n attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)\n # N, Len_q, n_heads, n_levels, n_points, 2\n num_points = refer_bbox.shape[-1]\n if num_points == 2:\n offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)\n add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]\n sampling_locations = refer_bbox[:, :, None, :, None, :] + add\n elif num_points == 4:\n add = 
sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5\n sampling_locations = refer_bbox[:, :, None, :, None, :2] + add\n else:\n raise ValueError(f\"Last dim of reference_points must be 2 or 4, but got {num_points}.\")\n output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)\n return self.output_proj(output)", "chunk_type": "class", "name": "MSDeformAttn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 457, "end_line": 580, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.\n\nThis module implements multiscale deformable attention that can attend to features at multiple scales\nwith learnable sampling locations and attention weights.\n\nAttributes:\n im2col_step (int): Step size for im2col operations.\n d_model (int): Model dimension.\n n_levels (int): Number of feature levels.\n n_heads (int): Number of attention heads.\n n_points (int): Number of sampling points per attention head per feature level.\n sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.\n attention_weights (nn.Linear): Linear layer for generating attention weights.\n value_proj (nn.Linear): Linear layer for projecting values.\n output_proj (nn.Linear): Linear layer for projecting output.\n\nReferences:\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_MSDeformAttn_b3c584f0" }, { "content": "class DeformableTransformerDecoderLayer(nn.Module):\n \"\"\"\n Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.\n\n This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable\n attention, and a feedforward network.\n\n Attributes:\n self_attn (nn.MultiheadAttention): Self-attention module.\n dropout1 (nn.Dropout): Dropout after self-attention.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn (MSDeformAttn): Cross-attention module.\n dropout2 (nn.Dropout): Dropout after cross-attention.\n norm2 (nn.LayerNorm): Layer normalization after cross-attention.\n linear1 (nn.Linear): First linear layer in the feedforward network.\n act (nn.Module): Activation function.\n dropout3 (nn.Dropout): Dropout in the feedforward network.\n linear2 (nn.Linear): Second linear layer in the feedforward network.\n dropout4 (nn.Dropout): Dropout after the feedforward network.\n norm3 (nn.LayerNorm): Layer normalization after the feedforward network.\n\n References:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py\n \"\"\"\n\n def __init__(\n self,\n d_model: int = 256,\n n_heads: int = 8,\n d_ffn: int = 1024,\n dropout: float = 0.0,\n act: nn.Module = nn.ReLU(),\n n_levels: int = 4,\n n_points: int = 4,\n ):\n \"\"\"\n Initialize the 
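A self-contained sketch of MSDeformAttn with 2-coordinate reference points, assuming the `ultralytics.nn.modules.transformer` import path and illustrative level shapes; the flattened value length must equal the sum of H*W over the levels, which the module asserts.

```python
import torch
from ultralytics.nn.modules.transformer import MSDeformAttn  # assumed import path

shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]                 # illustrative (H, W) per level
len_v = sum(h * w for h, w in shapes)                         # 1360 flattened keys
attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4)

query = torch.randn(1, 300, 256)
value = torch.randn(1, len_v, 256)
refer_bbox = torch.rand(1, 300, 4, 2)                         # normalized (x, y) reference points per query and level
print(attn(query, refer_bbox, value, shapes).shape)           # torch.Size([1, 300, 256])
```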
DeformableTransformerDecoderLayer with the given parameters.\n\n Args:\n d_model (int): Model dimension.\n n_heads (int): Number of attention heads.\n d_ffn (int): Dimension of the feedforward network.\n dropout (float): Dropout probability.\n act (nn.Module): Activation function.\n n_levels (int): Number of feature levels.\n n_points (int): Number of sampling points.\n \"\"\"\n super().__init__()\n\n # Self attention\n self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)\n self.dropout1 = nn.Dropout(dropout)\n self.norm1 = nn.LayerNorm(d_model)\n\n # Cross attention\n self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n self.dropout2 = nn.Dropout(dropout)\n self.norm2 = nn.LayerNorm(d_model)\n\n # FFN\n self.linear1 = nn.Linear(d_model, d_ffn)\n self.act = act\n self.dropout3 = nn.Dropout(dropout)\n self.linear2 = nn.Linear(d_ffn, d_model)\n self.dropout4 = nn.Dropout(dropout)\n self.norm3 = nn.LayerNorm(d_model)\n\n @staticmethod\n def with_pos_embed(tensor: torch.Tensor, pos: Optional[torch.Tensor]) -> torch.Tensor:\n \"\"\"Add positional embeddings to the input tensor, if provided.\"\"\"\n return tensor if pos is None else tensor + pos\n\n def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform forward pass through the Feed-Forward Network part of the layer.\n\n Args:\n tgt (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after FFN.\n \"\"\"\n tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))\n tgt = tgt + self.dropout4(tgt2)\n return self.norm3(tgt)\n\n def forward(\n self,\n embed: torch.Tensor,\n refer_bbox: torch.Tensor,\n feats: torch.Tensor,\n shapes: List,\n padding_mask: Optional[torch.Tensor] = None,\n attn_mask: Optional[torch.Tensor] = None,\n query_pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform the forward pass through the entire decoder layer.\n\n Args:\n embed (torch.Tensor): Input embeddings.\n refer_bbox (torch.Tensor): Reference bounding boxes.\n feats (torch.Tensor): Feature maps.\n shapes (list): Feature shapes.\n padding_mask (torch.Tensor, optional): Padding mask.\n attn_mask (torch.Tensor, optional): Attention mask.\n query_pos (torch.Tensor, optional): Query position embeddings.\n\n Returns:\n (torch.Tensor): Output tensor after decoder layer.\n \"\"\"\n # Self attention\n q = k = self.with_pos_embed(embed, query_pos)\n tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[\n 0\n ].transpose(0, 1)\n embed = embed + self.dropout1(tgt)\n embed = self.norm1(embed)\n\n # Cross attention\n tgt = self.cross_attn(\n self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask\n )\n embed = embed + self.dropout2(tgt)\n embed = self.norm2(embed)\n\n # FFN\n return self.forward_ffn(embed)", "chunk_type": "class", "name": "DeformableTransformerDecoderLayer", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 583, "end_line": 711, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.\n\nThis class implements a single decoder layer with self-attention, cross-attention using multiscale deformable\nattention, and a feedforward network.\n\nAttributes:\n self_attn (nn.MultiheadAttention): Self-attention module.\n dropout1 (nn.Dropout): Dropout after self-attention.\n norm1 (nn.LayerNorm): Layer normalization after 
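A minimal sketch of a single decoder layer, reusing the illustrative shapes from the previous example (assumed import path as before); 4-coordinate reference boxes trigger the box-relative offset branch inside the cross-attention.

```python
import torch
from ultralytics.nn.modules.transformer import DeformableTransformerDecoderLayer  # assumed import path

layer = DeformableTransformerDecoderLayer(d_model=256, n_heads=8, d_ffn=1024, n_levels=4, n_points=4)
shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]
feats = torch.randn(1, sum(h * w for h, w in shapes), 256)    # flattened multi-scale memory
embed = torch.randn(1, 300, 256)                              # decoder queries
refer_bbox = torch.rand(1, 300, 4)                            # normalized cxcywh reference boxes
print(layer(embed, refer_bbox, feats, shapes).shape)          # torch.Size([1, 300, 256])
```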
self-attention.\n cross_attn (MSDeformAttn): Cross-attention module.\n dropout2 (nn.Dropout): Dropout after cross-attention.\n norm2 (nn.LayerNorm): Layer normalization after cross-attention.\n linear1 (nn.Linear): First linear layer in the feedforward network.\n act (nn.Module): Activation function.\n dropout3 (nn.Dropout): Dropout in the feedforward network.\n linear2 (nn.Linear): Second linear layer in the feedforward network.\n dropout4 (nn.Dropout): Dropout after the feedforward network.\n norm3 (nn.LayerNorm): Layer normalization after the feedforward network.\n\nReferences:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_DeformableTransformerDecoderLayer_be16c3bf" }, { "content": "class DeformableTransformerDecoder(nn.Module):\n \"\"\"\n Deformable Transformer Decoder based on PaddleDetection implementation.\n\n This class implements a complete deformable transformer decoder with multiple decoder layers and prediction\n heads for bounding box regression and classification.\n\n Attributes:\n layers (nn.ModuleList): List of decoder layers.\n num_layers (int): Number of decoder layers.\n hidden_dim (int): Hidden dimension.\n eval_idx (int): Index of the layer to use during evaluation.\n\n References:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n \"\"\"\n\n def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):\n \"\"\"\n Initialize the DeformableTransformerDecoder with the given parameters.\n\n Args:\n hidden_dim (int): Hidden dimension.\n decoder_layer (nn.Module): Decoder layer module.\n num_layers (int): Number of decoder layers.\n eval_idx (int): Index of the layer to use during evaluation.\n \"\"\"\n super().__init__()\n self.layers = _get_clones(decoder_layer, num_layers)\n self.num_layers = num_layers\n self.hidden_dim = hidden_dim\n self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx\n\n def forward(\n self,\n embed: torch.Tensor, # decoder embeddings\n refer_bbox: torch.Tensor, # anchor\n feats: torch.Tensor, # image features\n shapes: List, # feature shapes\n bbox_head: nn.Module,\n score_head: nn.Module,\n pos_mlp: nn.Module,\n attn_mask: Optional[torch.Tensor] = None,\n padding_mask: Optional[torch.Tensor] = None,\n ):\n \"\"\"\n Perform the forward pass through the entire decoder.\n\n Args:\n embed (torch.Tensor): Decoder embeddings.\n refer_bbox (torch.Tensor): Reference bounding boxes.\n feats (torch.Tensor): Image features.\n shapes (list): Feature shapes.\n bbox_head (nn.Module): Bounding box prediction head.\n score_head (nn.Module): Score prediction head.\n pos_mlp (nn.Module): Position MLP.\n attn_mask (torch.Tensor, optional): Attention mask.\n padding_mask (torch.Tensor, optional): Padding mask.\n\n Returns:\n dec_bboxes (torch.Tensor): Decoded bounding boxes.\n dec_cls (torch.Tensor): Decoded classification scores.\n \"\"\"\n output = embed\n dec_bboxes = 
[]\n dec_cls = []\n last_refined_bbox = None\n refer_bbox = refer_bbox.sigmoid()\n for i, layer in enumerate(self.layers):\n output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))\n\n bbox = bbox_head[i](output)\n refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))\n\n if self.training:\n dec_cls.append(score_head[i](output))\n if i == 0:\n dec_bboxes.append(refined_bbox)\n else:\n dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))\n elif i == self.eval_idx:\n dec_cls.append(score_head[i](output))\n dec_bboxes.append(refined_bbox)\n break\n\n last_refined_bbox = refined_bbox\n refer_bbox = refined_bbox.detach() if self.training else refined_bbox\n\n return torch.stack(dec_bboxes), torch.stack(dec_cls)", "chunk_type": "class", "name": "DeformableTransformerDecoder", "file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py", "start_line": 714, "end_line": 802, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": "Deformable Transformer Decoder based on PaddleDetection implementation.\n\nThis class implements a complete deformable transformer decoder with multiple decoder layers and prediction\nheads for bounding box regression and classification.\n\nAttributes:\n layers (nn.ModuleList): List of decoder layers.\n num_layers (int): Number of decoder layers.\n hidden_dim (int): Hidden dimension.\n eval_idx (int): Index of the layer to use during evaluation.\n\nReferences:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "math", "typing.List", "typing.Optional", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.constant_", "torch.nn.init.xavier_uniform_", "conv.Conv", "utils._get_clones", "utils.inverse_sigmoid", "utils.multi_scale_deformable_attn_pytorch", "utils.torch_utils.TORCH_1_9", "nn.Module" ], "chunk_id": "class_DeformableTransformerDecoder_ac5c8729" }, { "content": "import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_6a0eaf33" }, { "content": "import math", "chunk_type": "import", "name": "math", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_math_a6e310e4" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_c5fb0e69" }, { "content": "import torch", "chunk_type": "import", "name": "torch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, 
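A sketch wiring the full decoder together with per-layer MLP box heads and linear score heads, mirroring how RTDETRDecoder uses it; the import path, sizes, and class count are illustrative assumptions. In eval mode only the `eval_idx` layer's predictions are stacked and returned.

```python
import torch
import torch.nn as nn
from ultralytics.nn.modules.transformer import (  # assumed import path
    MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer)

hd, nq, ndl, nc = 256, 300, 6, 80
layer = DeformableTransformerDecoderLayer(hd, 8, 1024, 0.0, nn.ReLU(), n_levels=4, n_points=4)
decoder = DeformableTransformerDecoder(hd, layer, ndl, eval_idx=-1).eval()

shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]
feats = torch.randn(1, sum(h * w for h, w in shapes), hd)
embed = torch.randn(1, nq, hd)                                # decoder query embeddings
refer_bbox = torch.randn(1, nq, 4)                            # logits; the decoder applies sigmoid internally
bbox_heads = nn.ModuleList(MLP(hd, hd, 4, 3) for _ in range(ndl))
score_heads = nn.ModuleList(nn.Linear(hd, nc) for _ in range(ndl))
pos_mlp = MLP(4, 2 * hd, hd, 2)                               # maps reference boxes to query position embeddings

boxes, scores = decoder(embed, refer_bbox, feats, shapes, bbox_heads, score_heads, pos_mlp)
print(boxes.shape, scores.shape)                              # torch.Size([1, 1, 300, 4]) torch.Size([1, 1, 300, 80])
```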
"dependencies": null, "chunk_id": "import_torch_8b75e7ed" }, { "content": "import torch.nn as nn", "chunk_type": "import", "name": "torch.nn", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn_b05f2489" }, { "content": "import torch.nn.functional as F", "chunk_type": "import", "name": "torch.nn.functional", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 9, "end_line": 9, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_torch.nn.functional_5e58c264" }, { "content": "from torch.nn.init import uniform_", "chunk_type": "import", "name": "uniform_", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 34, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_uniform__81c9763b" }, { "content": "__all__ = \"multi_scale_deformable_attn_pytorch\", \"inverse_sigmoid\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 12, "end_line": 12, "start_col": 0, "end_col": 66, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___85de65a2" }, { "content": "def _get_clones(module, n):\n \"\"\"\n Create a list of cloned modules from the given module.\n\n Args:\n module (nn.Module): The module to be cloned.\n n (int): Number of clones to create.\n\n Returns:\n (nn.ModuleList): A ModuleList containing n clones of the input module.\n\n Examples:\n >>> import torch.nn as nn\n >>> layer = nn.Linear(10, 10)\n >>> clones = _get_clones(layer, 3)\n >>> len(clones)\n 3\n \"\"\"\n return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])", "chunk_type": "function", "name": "_get_clones", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 15, "end_line": 33, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "Create a list of cloned modules from the given module.\n\nArgs:\n module (nn.Module): The module to be cloned.\n n (int): Number of clones to create.\n\nReturns:\n (nn.ModuleList): A ModuleList containing n clones of the input module.\n\nExamples:\n >>> import torch.nn as nn\n >>> layer = nn.Linear(10, 10)\n >>> clones = _get_clones(layer, 3)\n >>> len(clones)\n 3", "parameters": [ "module", "n" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "copy", "math", "numpy", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.uniform_" ], "chunk_id": "function__get_clones_c4c6091d" }, { "content": "def bias_init_with_prob(prior_prob=0.01):\n \"\"\"\n Initialize conv/fc bias value according to a given probability value.\n\n This function calculates the bias initialization value based on a prior probability using the inverse error function.\n It's commonly used in object detection models to initialize classification layers with a specific positive prediction\n probability.\n\n Args:\n prior_prob (float, optional): Prior probability 
for bias initialization.\n\n Returns:\n (float): Bias initialization value calculated from the prior probability.\n\n Examples:\n >>> bias = bias_init_with_prob(0.01)\n >>> print(f\"Bias initialization value: {bias:.4f}\")\n Bias initialization value: -4.5951\n \"\"\"\n return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init", "chunk_type": "function", "name": "bias_init_with_prob", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 36, "end_line": 55, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "Initialize conv/fc bias value according to a given probability value.\n\nThis function calculates the bias initialization value based on a prior probability using the inverse error function.\nIt's commonly used in object detection models to initialize classification layers with a specific positive prediction\nprobability.\n\nArgs:\n prior_prob (float, optional): Prior probability for bias initialization.\n\nReturns:\n (float): Bias initialization value calculated from the prior probability.\n\nExamples:\n >>> bias = bias_init_with_prob(0.01)\n >>> print(f\"Bias initialization value: {bias:.4f}\")\n Bias initialization value: -4.5951", "parameters": [ "prior_prob" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "copy", "math", "numpy", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.uniform_" ], "chunk_id": "function_bias_init_with_prob_e7a9d08b" }, { "content": "def linear_init(module):\n \"\"\"\n Initialize the weights and biases of a linear module.\n\n This function initializes the weights of a linear module using a uniform distribution within bounds calculated\n from the input dimension. If the module has a bias, it is also initialized.\n\n Args:\n module (nn.Module): Linear module to initialize.\n\n Returns:\n (nn.Module): The initialized module.\n\n Examples:\n >>> import torch.nn as nn\n >>> linear = nn.Linear(10, 5)\n >>> initialized_linear = linear_init(linear)\n \"\"\"\n bound = 1 / math.sqrt(module.weight.shape[0])\n uniform_(module.weight, -bound, bound)\n if hasattr(module, \"bias\") and module.bias is not None:\n uniform_(module.bias, -bound, bound)", "chunk_type": "function", "name": "linear_init", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 58, "end_line": 79, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Initialize the weights and biases of a linear module.\n\nThis function initializes the weights of a linear module using a uniform distribution within bounds calculated\nfrom the input dimension. 
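The value returned by `bias_init_with_prob` is the logit (inverse sigmoid) of the prior probability, i.e. `-log((1 - p) / p)`, so applying a sigmoid to the bias recovers the prior (the docstring's mention of the "inverse error function" describes the same idea loosely; the code uses the logit). A minimal check, assuming the same import path as above:

```python
import math
from ultralytics.nn.modules.utils import bias_init_with_prob  # assumed import path

prior = 0.01
b = bias_init_with_prob(prior)            # -log((1 - p) / p), i.e. the logit of the prior
print(round(b, 4))                        # -4.5951
print(round(1 / (1 + math.exp(-b)), 4))   # 0.01 -> sigmoid(bias) recovers the prior probability
```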
If the module has a bias, it is also initialized.\n\nArgs:\n module (nn.Module): Linear module to initialize.\n\nReturns:\n (nn.Module): The initialized module.\n\nExamples:\n >>> import torch.nn as nn\n >>> linear = nn.Linear(10, 5)\n >>> initialized_linear = linear_init(linear)", "parameters": [ "module" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "copy", "math", "numpy", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.uniform_" ], "chunk_id": "function_linear_init_1a00c166" }, { "content": "def inverse_sigmoid(x, eps=1e-5):\n \"\"\"\n Calculate the inverse sigmoid function for a tensor.\n\n This function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network\n operations, particularly in attention mechanisms and coordinate transformations.\n\n Args:\n x (torch.Tensor): Input tensor with values in range [0, 1].\n eps (float, optional): Small epsilon value to prevent numerical instability.\n\n Returns:\n (torch.Tensor): Tensor after applying the inverse sigmoid function.\n\n Examples:\n >>> x = torch.tensor([0.2, 0.5, 0.8])\n >>> inverse_sigmoid(x)\n tensor([-1.3863, 0.0000, 1.3863])\n \"\"\"\n x = x.clamp(min=0, max=1)\n x1 = x.clamp(min=eps)\n x2 = (1 - x).clamp(min=eps)\n return torch.log(x1 / x2)", "chunk_type": "function", "name": "inverse_sigmoid", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 82, "end_line": 104, "start_col": 0, "end_col": 29, "parent_name": null, "docstring": "Calculate the inverse sigmoid function for a tensor.\n\nThis function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network\noperations, particularly in attention mechanisms and coordinate transformations.\n\nArgs:\n x (torch.Tensor): Input tensor with values in range [0, 1].\n eps (float, optional): Small epsilon value to prevent numerical instability.\n\nReturns:\n (torch.Tensor): Tensor after applying the inverse sigmoid function.\n\nExamples:\n >>> x = torch.tensor([0.2, 0.5, 0.8])\n >>> inverse_sigmoid(x)\n tensor([-1.3863, 0.0000, 1.3863])", "parameters": [ "x", "eps" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "copy", "math", "numpy", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.uniform_" ], "chunk_id": "function_inverse_sigmoid_9e224e18" }, { "content": "def multi_scale_deformable_attn_pytorch(\n value: torch.Tensor,\n value_spatial_shapes: torch.Tensor,\n sampling_locations: torch.Tensor,\n attention_weights: torch.Tensor,\n) -> torch.Tensor:\n \"\"\"\n Implement multi-scale deformable attention in PyTorch.\n\n This function performs deformable attention across multiple feature map scales, allowing the model to attend to\n different spatial locations with learned offsets.\n\n Args:\n value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).\n value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).\n sampling_locations (torch.Tensor): The sampling locations with shape\n (bs, num_queries, num_heads, num_levels, num_points, 2).\n attention_weights (torch.Tensor): The attention weights with shape\n (bs, num_queries, num_heads, num_levels, num_points).\n\n Returns:\n (torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).\n\n References:\n https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py\n \"\"\"\n bs, _, num_heads, embed_dims = value.shape\n _, 
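A quick round-trip sketch for `inverse_sigmoid`, also showing how the `eps` clamp keeps the endpoint values finite (outputs are approximate):

```python
import torch
from ultralytics.nn.modules.utils import inverse_sigmoid  # assumed import path

x = torch.tensor([0.2, 0.5, 0.8])
y = inverse_sigmoid(x)
print(y)                                  # tensor([-1.3863,  0.0000,  1.3863])
print(torch.sigmoid(y))                   # recovers the input (up to the eps clamp)
print(inverse_sigmoid(torch.tensor([0.0, 1.0])))  # clamped to roughly +/-11.51 instead of +/-inf
```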
num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape\n value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)\n sampling_grids = 2 * sampling_locations - 1\n sampling_value_list = []\n for level, (H_, W_) in enumerate(value_spatial_shapes):\n # bs, H_*W_, num_heads, embed_dims ->\n # bs, H_*W_, num_heads*embed_dims ->\n # bs, num_heads*embed_dims, H_*W_ ->\n # bs*num_heads, embed_dims, H_, W_\n value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)\n # bs, num_queries, num_heads, num_points, 2 ->\n # bs, num_heads, num_queries, num_points, 2 ->\n # bs*num_heads, num_queries, num_points, 2\n sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)\n # bs*num_heads, embed_dims, num_queries, num_points\n sampling_value_l_ = F.grid_sample(\n value_l_, sampling_grid_l_, mode=\"bilinear\", padding_mode=\"zeros\", align_corners=False\n )\n sampling_value_list.append(sampling_value_l_)\n # (bs, num_queries, num_heads, num_levels, num_points) ->\n # (bs, num_heads, num_queries, num_levels, num_points) ->\n # (bs, num_heads, 1, num_queries, num_levels*num_points)\n attention_weights = attention_weights.transpose(1, 2).reshape(\n bs * num_heads, 1, num_queries, num_levels * num_points\n )\n output = (\n (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)\n .sum(-1)\n .view(bs, num_heads * embed_dims, num_queries)\n )\n return output.transpose(1, 2).contiguous()", "chunk_type": "function", "name": "multi_scale_deformable_attn_pytorch", "file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py", "start_line": 107, "end_line": 164, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Implement multi-scale deformable attention in PyTorch.\n\nThis function performs deformable attention across multiple feature map scales, allowing the model to attend to\ndifferent spatial locations with learned offsets.\n\nArgs:\n value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).\n value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).\n sampling_locations (torch.Tensor): The sampling locations with shape\n (bs, num_queries, num_heads, num_levels, num_points, 2).\n attention_weights (torch.Tensor): The attention weights with shape\n (bs, num_queries, num_heads, num_levels, num_points).\n\nReturns:\n (torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).\n\nReferences:\n https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py", "parameters": [ "value: torch.Tensor", "value_spatial_shapes: torch.Tensor", "sampling_locations: torch.Tensor", "attention_weights: torch.Tensor" ], "return_type": "torch.Tensor", "decorators": [], "complexity_score": 3, "dependencies": [ "copy", "math", "numpy", "torch", "torch.nn", "torch.nn.functional", "torch.nn.init.uniform_" ], "chunk_id": "function_multi_scale_deformable_attn_pytorch_46591a4c" }, { "content": "from .block import (\n C1,\n C2,\n C2PSA,\n C3,\n C3TR,\n CIB,\n DFL,\n ELAN1,\n PSA,\n SPP,\n SPPELAN,\n SPPF,\n A2C2f,\n AConv,\n ADown,\n Attention,\n BNContrastiveHead,\n Bottleneck,\n BottleneckCSP,\n C2f,\n C2fAttn,\n C2fCIB,\n C2fPSA,\n C3Ghost,\n C3k2,\n C3x,\n CBFuse,\n CBLinear,\n ContrastiveHead,\n GhostBottleneck,\n HGBlock,\n HGStem,\n ImagePoolingAttn,\n MaxSigmoidAttnBlock,\n Proto,\n RepC3,\n RepNCSPELAN4,\n RepVGGDW,\n ResNetLayer,\n SCDown,\n TorchVision,\n)", "chunk_type": 
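A shape walk-through for `multi_scale_deformable_attn_pytorch` with two hypothetical feature-map levels. Note that `embed_dims` in the value tensor is the per-head dimension, so the output's last dimension is `num_heads * embed_dims`; passing `value_spatial_shapes` as a plain list of `(H, W)` tuples is an assumption that matches how the function iterates over it.

```python
import torch
from ultralytics.nn.modules.utils import multi_scale_deformable_attn_pytorch  # assumed import path

bs, num_heads, head_dim = 2, 8, 32
shapes = [(32, 32), (16, 16)]                      # two feature-map levels (H, W)
num_keys = sum(h * w for h, w in shapes)           # 1024 + 256 = 1280 flattened keys
num_queries, num_points = 100, 4

value = torch.rand(bs, num_keys, num_heads, head_dim)
sampling_locations = torch.rand(bs, num_queries, num_heads, len(shapes), num_points, 2)  # in [0, 1]
attention_weights = torch.rand(bs, num_queries, num_heads, len(shapes), num_points)
attention_weights = attention_weights / attention_weights.sum((-2, -1), keepdim=True)    # normalize

out = multi_scale_deformable_attn_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]) -> (bs, num_queries, num_heads * head_dim)
```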
"import", "name": "C1, C2, C2PSA, C3, C3TR, CIB, DFL, ELAN1, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown, Attention, BNContrastiveHead, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, ContrastiveHead, GhostBottleneck, HGBlock, HGStem, ImagePoolingAttn, MaxSigmoidAttnBlock, Proto, RepC3, RepNCSPELAN4, RepVGGDW, ResNetLayer, SCDown, TorchVision", "file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py", "start_line": 20, "end_line": 62, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_C1, C2, C2PSA, C3, C3TR, CIB, DFL, ELAN1, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown, Attention, BNContrastiveHead, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, ContrastiveHead, GhostBottleneck, HGBlock, HGStem, ImagePoolingAttn, MaxSigmoidAttnBlock, Proto, RepC3, RepNCSPELAN4, RepVGGDW, ResNetLayer, SCDown, TorchVision_65ef2def" }, { "content": "from .conv import (\n CBAM,\n ChannelAttention,\n Concat,\n Conv,\n Conv2,\n ConvTranspose,\n DWConv,\n DWConvTranspose2d,\n Focus,\n GhostConv,\n Index,\n LightConv,\n RepConv,\n SpatialAttention,\n)", "chunk_type": "import", "name": "CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv, Index, LightConv, RepConv, SpatialAttention", "file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py", "start_line": 63, "end_line": 78, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv, Index, LightConv, RepConv, SpatialAttention_a70cef6e" }, { "content": "from .head import (\n OBB,\n Classify,\n Detect,\n LRPCHead,\n Pose,\n RTDETRDecoder,\n Segment,\n WorldDetect,\n YOLOEDetect,\n YOLOESegment,\n v10Detect,\n)", "chunk_type": "import", "name": "OBB, Classify, Detect, LRPCHead, Pose, RTDETRDecoder, Segment, WorldDetect, YOLOEDetect, YOLOESegment, v10Detect", "file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py", "start_line": 79, "end_line": 91, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_OBB, Classify, Detect, LRPCHead, Pose, RTDETRDecoder, Segment, WorldDetect, YOLOEDetect, YOLOESegment, v10Detect_35091f91" }, { "content": "from .transformer import (\n AIFI,\n MLP,\n DeformableTransformerDecoder,\n DeformableTransformerDecoderLayer,\n LayerNorm2d,\n MLPBlock,\n MSDeformAttn,\n TransformerBlock,\n TransformerEncoderLayer,\n TransformerLayer,\n)", "chunk_type": "import", "name": "AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer", "file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py", "start_line": 92, "end_line": 103, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, MLPBlock, 
MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer_c27e4da3" }, { "content": "__all__ = (\n \"Conv\",\n \"Conv2\",\n \"LightConv\",\n \"RepConv\",\n \"DWConv\",\n \"DWConvTranspose2d\",\n \"ConvTranspose\",\n \"Focus\",\n \"GhostConv\",\n \"ChannelAttention\",\n \"SpatialAttention\",\n \"CBAM\",\n \"Concat\",\n \"TransformerLayer\",\n \"TransformerBlock\",\n \"MLPBlock\",\n \"LayerNorm2d\",\n \"DFL\",\n \"HGBlock\",\n \"HGStem\",\n \"SPP\",\n \"SPPF\",\n \"C1\",\n \"C2\",\n \"C3\",\n \"C2f\",\n \"C3k2\",\n \"SCDown\",\n \"C2fPSA\",\n \"C2PSA\",\n \"C2fAttn\",\n \"C3x\",\n \"C3TR\",\n \"C3Ghost\",\n \"GhostBottleneck\",\n \"Bottleneck\",\n \"BottleneckCSP\",\n \"Proto\",\n \"Detect\",\n \"Segment\",\n \"Pose\",\n \"Classify\",\n \"TransformerEncoderLayer\",\n \"RepC3\",\n \"RTDETRDecoder\",\n \"AIFI\",\n \"DeformableTransformerDecoder\",\n \"DeformableTransformerDecoderLayer\",\n \"MSDeformAttn\",\n \"MLP\",\n \"ResNetLayer\",\n \"OBB\",\n \"WorldDetect\",\n \"YOLOEDetect\",\n \"YOLOESegment\",\n \"v10Detect\",\n \"LRPCHead\",\n \"ImagePoolingAttn\",\n \"MaxSigmoidAttnBlock\",\n \"ContrastiveHead\",\n \"BNContrastiveHead\",\n \"RepNCSPELAN4\",\n \"ADown\",\n \"SPPELAN\",\n \"CBFuse\",\n \"CBLinear\",\n \"AConv\",\n \"ELAN1\",\n \"RepVGGDW\",\n \"CIB\",\n \"C2fCIB\",\n \"Attention\",\n \"PSA\",\n \"TorchVision\",\n \"Index\",\n \"A2C2f\",\n)", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py", "start_line": 105, "end_line": 182, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___b3080452" }, { "content": "import copy", "chunk_type": "import", "name": "copy", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_copy_6bb5f9c9" }, { "content": "from typing import List, Optional", "chunk_type": "import", "name": "List, Optional", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_List, Optional_845e34bc" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_a3836a3f" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_c9f60795" }, { "content": "from ultralytics.utils import LOGGER", "chunk_type": "import", "name": "LOGGER", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py", "start_line": 9, "end_line": 9, "start_col": 0, 
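The `__init__.py` above re-exports the building blocks listed in `__all__`, so they can be imported directly from `ultralytics.nn.modules`. A small sketch, assuming the `Conv(c1, c2, k, s)` and `SPPF(c1, c2)` signatures used elsewhere in the package:

```python
import torch
from ultralytics.nn.modules import Conv, SPPF  # re-exported via __all__

x = torch.rand(1, 3, 64, 64)
conv = Conv(3, 16, k=3, s=2)   # Conv2d + BatchNorm2d + SiLU, stride 2
sppf = SPPF(16, 16)            # spatial pyramid pooling (fast), spatial size preserved
y = sppf(conv(x))
print(y.shape)                 # torch.Size([1, 16, 32, 32])
```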
"end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER_40dda891" }, { "content": "class GMC:\n \"\"\"\n Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.\n\n This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,\n SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.\n\n Attributes:\n method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.\n downscale (int): Factor by which to downscale the frames for processing.\n prevFrame (np.ndarray): Previous frame for tracking.\n prevKeyPoints (List): Keypoints from the previous frame.\n prevDescriptors (np.ndarray): Descriptors from the previous frame.\n initializedFirstFrame (bool): Flag indicating if the first frame has been processed.\n\n Methods:\n apply: Apply the chosen method to a raw frame and optionally use provided detections.\n apply_ecc: Apply the ECC algorithm to a raw frame.\n apply_features: Apply feature-based methods like ORB or SIFT to a raw frame.\n apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame.\n reset_params: Reset the internal parameters of the GMC object.\n\n Examples:\n Create a GMC object and apply it to a frame\n >>> gmc = GMC(method=\"sparseOptFlow\", downscale=2)\n >>> frame = np.array([[1, 2, 3], [4, 5, 6]])\n >>> processed_frame = gmc.apply(frame)\n >>> print(processed_frame)\n array([[1, 2, 3],\n [4, 5, 6]])\n \"\"\"\n\n def __init__(self, method: str = \"sparseOptFlow\", downscale: int = 2) -> None:\n \"\"\"\n Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.\n\n Args:\n method (str): The tracking method to use. 
Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.\n downscale (int): Downscale factor for processing frames.\n\n Examples:\n Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2\n >>> gmc = GMC(method=\"sparseOptFlow\", downscale=2)\n \"\"\"\n super().__init__()\n\n self.method = method\n self.downscale = max(1, downscale)\n\n if self.method == \"orb\":\n self.detector = cv2.FastFeatureDetector_create(20)\n self.extractor = cv2.ORB_create()\n self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)\n\n elif self.method == \"sift\":\n self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)\n self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)\n self.matcher = cv2.BFMatcher(cv2.NORM_L2)\n\n elif self.method == \"ecc\":\n number_of_iterations = 5000\n termination_eps = 1e-6\n self.warp_mode = cv2.MOTION_EUCLIDEAN\n self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)\n\n elif self.method == \"sparseOptFlow\":\n self.feature_params = dict(\n maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04\n )\n\n elif self.method in {\"none\", \"None\", None}:\n self.method = None\n else:\n raise ValueError(f\"Unknown GMC method: {method}\")\n\n self.prevFrame = None\n self.prevKeyPoints = None\n self.prevDescriptors = None\n self.initializedFirstFrame = False\n\n def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:\n \"\"\"\n Apply object detection on a raw frame using the specified method.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n detections (List, optional): List of detections to be used in the processing.\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC(method=\"sparseOptFlow\")\n >>> raw_frame = np.random.rand(480, 640, 3)\n >>> transformation_matrix = gmc.apply(raw_frame)\n >>> print(transformation_matrix.shape)\n (2, 3)\n \"\"\"\n if self.method in {\"orb\", \"sift\"}:\n return self.apply_features(raw_frame, detections)\n elif self.method == \"ecc\":\n return self.apply_ecc(raw_frame)\n elif self.method == \"sparseOptFlow\":\n return self.apply_sparseoptflow(raw_frame)\n else:\n return np.eye(2, 3)\n\n def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray:\n \"\"\"\n Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC(method=\"ecc\")\n >>> processed_frame = gmc.apply_ecc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))\n >>> print(processed_frame)\n [[1. 0. 0.]\n [0. 1. 
0.]]\n \"\"\"\n height, width, c = raw_frame.shape\n frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame\n H = np.eye(2, 3, dtype=np.float32)\n\n # Downscale image for computational efficiency\n if self.downscale > 1.0:\n frame = cv2.GaussianBlur(frame, (3, 3), 1.5)\n frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))\n\n # Handle first frame initialization\n if not self.initializedFirstFrame:\n self.prevFrame = frame.copy()\n self.initializedFirstFrame = True\n return H\n\n # Run the ECC algorithm to find transformation matrix\n try:\n (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)\n except Exception as e:\n LOGGER.warning(f\"find transform failed. Set warp as identity {e}\")\n\n return H\n\n def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:\n \"\"\"\n Apply feature-based methods like ORB or SIFT to a raw frame.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n detections (List, optional): List of detections to be used in the processing.\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC(method=\"orb\")\n >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> transformation_matrix = gmc.apply_features(raw_frame)\n >>> print(transformation_matrix.shape)\n (2, 3)\n \"\"\"\n height, width, c = raw_frame.shape\n frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame\n H = np.eye(2, 3)\n\n # Downscale image for computational efficiency\n if self.downscale > 1.0:\n frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))\n width = width // self.downscale\n height = height // self.downscale\n\n # Create mask for keypoint detection, excluding border regions\n mask = np.zeros_like(frame)\n mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255\n\n # Exclude detection regions from mask to avoid tracking detected objects\n if detections is not None:\n for det in detections:\n tlbr = (det[:4] / self.downscale).astype(np.int_)\n mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0\n\n # Find keypoints and compute descriptors\n keypoints = self.detector.detect(frame, mask)\n keypoints, descriptors = self.extractor.compute(frame, keypoints)\n\n # Handle first frame initialization\n if not self.initializedFirstFrame:\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.prevDescriptors = copy.copy(descriptors)\n self.initializedFirstFrame = True\n return H\n\n # Match descriptors between previous and current frame\n knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)\n\n # Filter matches based on spatial distance constraints\n matches = []\n spatialDistances = []\n maxSpatialDistance = 0.25 * np.array([width, height])\n\n # Handle empty matches case\n if len(knnMatches) == 0:\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.prevDescriptors = copy.copy(descriptors)\n return H\n\n # Apply Lowe's ratio test and spatial distance filtering\n for m, n in knnMatches:\n if m.distance < 0.9 * n.distance:\n prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt\n currKeyPointLocation = keypoints[m.trainIdx].pt\n\n spatialDistance = (\n prevKeyPointLocation[0] - currKeyPointLocation[0],\n prevKeyPointLocation[1] - currKeyPointLocation[1],\n )\n\n if (np.abs(spatialDistance[0]) < 
maxSpatialDistance[0]) and (\n np.abs(spatialDistance[1]) < maxSpatialDistance[1]\n ):\n spatialDistances.append(spatialDistance)\n matches.append(m)\n\n # Filter outliers using statistical analysis\n meanSpatialDistances = np.mean(spatialDistances, 0)\n stdSpatialDistances = np.std(spatialDistances, 0)\n inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances\n\n # Extract good matches and corresponding points\n goodMatches = []\n prevPoints = []\n currPoints = []\n for i in range(len(matches)):\n if inliers[i, 0] and inliers[i, 1]:\n goodMatches.append(matches[i])\n prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)\n currPoints.append(keypoints[matches[i].trainIdx].pt)\n\n prevPoints = np.array(prevPoints)\n currPoints = np.array(currPoints)\n\n # Estimate transformation matrix using RANSAC\n if prevPoints.shape[0] > 4:\n H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)\n\n # Scale translation components back to original resolution\n if self.downscale > 1.0:\n H[0, 2] *= self.downscale\n H[1, 2] *= self.downscale\n else:\n LOGGER.warning(\"not enough matching points\")\n\n # Store current frame data for next iteration\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.prevDescriptors = copy.copy(descriptors)\n\n return H\n\n def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray:\n \"\"\"\n Apply Sparse Optical Flow method to a raw frame.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC()\n >>> result = gmc.apply_sparseoptflow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))\n >>> print(result)\n [[1. 0. 0.]\n [0. 1. 
0.]]\n \"\"\"\n height, width, c = raw_frame.shape\n frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame\n H = np.eye(2, 3)\n\n # Downscale image for computational efficiency\n if self.downscale > 1.0:\n frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))\n\n # Find good features to track\n keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)\n\n # Handle first frame initialization\n if not self.initializedFirstFrame or self.prevKeyPoints is None:\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.initializedFirstFrame = True\n return H\n\n # Calculate optical flow using Lucas-Kanade method\n matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)\n\n # Extract successfully tracked points\n prevPoints = []\n currPoints = []\n\n for i in range(len(status)):\n if status[i]:\n prevPoints.append(self.prevKeyPoints[i])\n currPoints.append(matchedKeypoints[i])\n\n prevPoints = np.array(prevPoints)\n currPoints = np.array(currPoints)\n\n # Estimate transformation matrix using RANSAC\n if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):\n H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)\n\n # Scale translation components back to original resolution\n if self.downscale > 1.0:\n H[0, 2] *= self.downscale\n H[1, 2] *= self.downscale\n else:\n LOGGER.warning(\"not enough matching points\")\n\n # Store current frame data for next iteration\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n\n return H\n\n def reset_params(self) -> None:\n \"\"\"Reset the internal parameters including previous frame, keypoints, and descriptors.\"\"\"\n self.prevFrame = None\n self.prevKeyPoints = None\n self.prevDescriptors = None\n self.initializedFirstFrame = False", "chunk_type": "class", "name": "GMC", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py", "start_line": 12, "end_line": 349, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.\n\nThis class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,\nSIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.\n\nAttributes:\n method (str): The tracking method to use. 
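An end-to-end sketch of the `GMC` API with the default sparse optical-flow method. The first call only initializes internal state and returns an identity warp; later calls estimate a 2x3 affine transform between consecutive frames. Random noise frames are used here purely so the sketch is self-contained; in practice consecutive video frames are passed.

```python
import numpy as np
from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method="sparseOptFlow", downscale=2)
frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # stand-in for a video frame
H0 = gmc.apply(frame)          # first frame only initializes state -> identity warp
H1 = gmc.apply(frame.copy())   # subsequent frames return an estimated 2x3 affine warp
print(H0.shape, H1.shape)      # (2, 3) (2, 3)
gmc.reset_params()             # clear stored frame/keypoints before starting a new video
```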
Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.\n downscale (int): Factor by which to downscale the frames for processing.\n prevFrame (np.ndarray): Previous frame for tracking.\n prevKeyPoints (List): Keypoints from the previous frame.\n prevDescriptors (np.ndarray): Descriptors from the previous frame.\n initializedFirstFrame (bool): Flag indicating if the first frame has been processed.\n\nMethods:\n apply: Apply the chosen method to a raw frame and optionally use provided detections.\n apply_ecc: Apply the ECC algorithm to a raw frame.\n apply_features: Apply feature-based methods like ORB or SIFT to a raw frame.\n apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame.\n reset_params: Reset the internal parameters of the GMC object.\n\nExamples:\n Create a GMC object and apply it to a frame\n >>> gmc = GMC(method=\"sparseOptFlow\", downscale=2)\n >>> frame = np.array([[1, 2, 3], [4, 5, 6]])\n >>> processed_frame = gmc.apply(frame)\n >>> print(processed_frame)\n array([[1, 2, 3],\n [4, 5, 6]])", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "copy", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER" ], "chunk_id": "class_GMC_ca52d141" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_ea19bbd7" }, { "content": "import scipy.linalg", "chunk_type": "import", "name": "scipy.linalg", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_scipy.linalg_f50fd77e" }, { "content": "class KalmanFilterXYAH:\n \"\"\"\n A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.\n\n Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space\n (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their\n respective velocities. 
Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is\n taken as a direct observation of the state space (linear observation model).\n\n Attributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\n Methods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step (vectorized version).\n update: Run the Kalman filter correction step.\n gating_distance: Compute the gating distance between state distribution and measurements.\n\n Examples:\n Initialize the Kalman filter and create a track from a measurement\n >>> kf = KalmanFilterXYAH()\n >>> measurement = np.array([100, 200, 1.5, 50])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initialize Kalman filter model matrices with motion and observation uncertainty weights.\n\n The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)\n represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective\n velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear\n observation model for bounding box location.\n\n Examples:\n Initialize a Kalman filter for tracking:\n >>> kf = KalmanFilterXYAH()\n \"\"\"\n ndim, dt = 4, 1.0\n\n # Create Kalman filter model matrices\n self._motion_mat = np.eye(2 * ndim, 2 * ndim)\n for i in range(ndim):\n self._motion_mat[i, ndim + i] = dt\n self._update_mat = np.eye(ndim, 2 * ndim)\n\n # Motion and observation uncertainty are chosen relative to the current state estimate\n self._std_weight_position = 1.0 / 20\n self._std_weight_velocity = 1.0 / 160\n\n def initiate(self, measurement: np.ndarray):\n \"\"\"\n Create a track from an unassociated measurement.\n\n Args:\n measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,\n and height h.\n\n Returns:\n mean (np.ndarray): Mean vector (8-dimensional) of the new track. 
Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> measurement = np.array([100, 50, 1.5, 200])\n >>> mean, covariance = kf.initiate(measurement)\n \"\"\"\n mean_pos = measurement\n mean_vel = np.zeros_like(mean_pos)\n mean = np.r_[mean_pos, mean_vel]\n\n std = [\n 2 * self._std_weight_position * measurement[3],\n 2 * self._std_weight_position * measurement[3],\n 1e-2,\n 2 * self._std_weight_position * measurement[3],\n 10 * self._std_weight_velocity * measurement[3],\n 10 * self._std_weight_velocity * measurement[3],\n 1e-5,\n 10 * self._std_weight_velocity * measurement[3],\n ]\n covariance = np.diag(np.square(std))\n return mean, covariance\n\n def predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step.\n\n Args:\n mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.\n covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix of the predicted state.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[3],\n 1e-2,\n self._std_weight_position * mean[3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[3],\n self._std_weight_velocity * mean[3],\n 1e-5,\n self._std_weight_velocity * mean[3],\n ]\n motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))\n\n mean = np.dot(mean, self._motion_mat.T)\n covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov\n\n return mean, covariance\n\n def project(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Project state distribution to measurement space.\n\n Args:\n mean (np.ndarray): The state's mean vector (8 dimensional array).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n\n Returns:\n mean (np.ndarray): Projected mean of the given state estimate.\n covariance (np.ndarray): Projected covariance matrix of the given state estimate.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> projected_mean, projected_covariance = kf.project(mean, covariance)\n \"\"\"\n std = [\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[3],\n 1e-1,\n self._std_weight_position * mean[3],\n ]\n innovation_cov = np.diag(np.square(std))\n\n mean = np.dot(self._update_mat, mean)\n covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))\n return mean, covariance + innovation_cov\n\n def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step for multiple object states (Vectorized version).\n\n Args:\n mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.\n covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).\n covariance (np.ndarray): Covariance matrix of the predicted 
states with shape (N, 8, 8).\n\n Examples:\n >>> mean = np.random.rand(10, 8) # 10 object states\n >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states\n >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[:, 3],\n self._std_weight_position * mean[:, 3],\n 1e-2 * np.ones_like(mean[:, 3]),\n self._std_weight_position * mean[:, 3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[:, 3],\n self._std_weight_velocity * mean[:, 3],\n 1e-5 * np.ones_like(mean[:, 3]),\n self._std_weight_velocity * mean[:, 3],\n ]\n sqr = np.square(np.r_[std_pos, std_vel]).T\n\n motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]\n motion_cov = np.asarray(motion_cov)\n\n mean = np.dot(mean, self._motion_mat.T)\n left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))\n covariance = np.dot(left, self._motion_mat.T) + motion_cov\n\n return mean, covariance\n\n def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):\n \"\"\"\n Run Kalman filter correction step.\n\n Args:\n mean (np.ndarray): The predicted state's mean vector (8 dimensional).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n measurement (np.ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center\n position, a the aspect ratio, and h the height of the bounding box.\n\n Returns:\n new_mean (np.ndarray): Measurement-corrected state mean.\n new_covariance (np.ndarray): Measurement-corrected state covariance.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> measurement = np.array([1, 1, 1, 1])\n >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)\n \"\"\"\n projected_mean, projected_cov = self.project(mean, covariance)\n\n chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)\n kalman_gain = scipy.linalg.cho_solve(\n (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False\n ).T\n innovation = measurement - projected_mean\n\n new_mean = mean + np.dot(innovation, kalman_gain.T)\n new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))\n return new_mean, new_covariance\n\n def gating_distance(\n self,\n mean: np.ndarray,\n covariance: np.ndarray,\n measurements: np.ndarray,\n only_position: bool = False,\n metric: str = \"maha\",\n ) -> np.ndarray:\n \"\"\"\n Compute gating distance between state distribution and measurements.\n\n A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square\n distribution has 4 degrees of freedom, otherwise 2.\n\n Args:\n mean (np.ndarray): Mean vector over the state distribution (8 dimensional).\n covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).\n measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the\n bounding box center position, a the aspect ratio, and h the height.\n only_position (bool, optional): If True, distance computation is done with respect to box center position only.\n metric (str, optional): The metric to use for calculating the distance. 
Options are 'gaussian' for the squared\n Euclidean distance and 'maha' for the squared Mahalanobis distance.\n\n Returns:\n (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between\n (mean, covariance) and `measurements[i]`.\n\n Examples:\n Compute gating distance using Mahalanobis metric:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])\n >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric=\"maha\")\n \"\"\"\n mean, covariance = self.project(mean, covariance)\n if only_position:\n mean, covariance = mean[:2], covariance[:2, :2]\n measurements = measurements[:, :2]\n\n d = measurements - mean\n if metric == \"gaussian\":\n return np.sum(d * d, axis=1)\n elif metric == \"maha\":\n cholesky_factor = np.linalg.cholesky(covariance)\n z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)\n return np.sum(z * z, axis=0) # square maha\n else:\n raise ValueError(\"Invalid distance metric\")", "chunk_type": "class", "name": "KalmanFilterXYAH", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py", "start_line": 7, "end_line": 286, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.\n\nImplements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space\n(x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their\nrespective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is\ntaken as a direct observation of the state space (linear observation model).\n\nAttributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\nMethods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step (vectorized version).\n update: Run the Kalman filter correction step.\n gating_distance: Compute the gating distance between state distribution and measurements.\n\nExamples:\n Initialize the Kalman filter and create a track from a measurement\n >>> kf = KalmanFilterXYAH()\n >>> measurement = np.array([100, 200, 1.5, 50])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "numpy", "scipy.linalg" ], "chunk_id": "class_KalmanFilterXYAH_03d35260" }, { "content": "class KalmanFilterXYWH(KalmanFilterXYAH):\n \"\"\"\n A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.\n\n Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where\n (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.\n The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct\n observation 
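A compact predict/update cycle for `KalmanFilterXYAH`, tying the methods above together; the measurement values are arbitrary `(x, y, aspect ratio, height)` numbers chosen for illustration.

```python
import numpy as np
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

kf = KalmanFilterXYAH()
z0 = np.array([320.0, 240.0, 0.5, 80.0])          # (x, y, aspect ratio, height)
mean, cov = kf.initiate(z0)                        # 8-dim state, 8x8 covariance
mean, cov = kf.predict(mean, cov)                  # constant-velocity prediction
mean, cov = kf.update(mean, cov, np.array([322.0, 241.0, 0.5, 81.0]))
d = kf.gating_distance(mean, cov, np.array([[322.0, 241.0, 0.5, 81.0]]))
print(mean.shape, cov.shape, d.shape)              # (8,) (8, 8) (1,)
```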
of the state space (linear observation model).\n\n Attributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\n Methods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step in a vectorized manner.\n update: Run the Kalman filter correction step.\n\n Examples:\n Create a Kalman filter and initialize a track\n >>> kf = KalmanFilterXYWH()\n >>> measurement = np.array([100, 50, 20, 40])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)\n \"\"\"\n\n def initiate(self, measurement: np.ndarray):\n \"\"\"\n Create track from unassociated measurement.\n\n Args:\n measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.\n\n Returns:\n mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> measurement = np.array([100, 50, 20, 40])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n [100. 50. 20. 40. 0. 0. 0. 0.]\n >>> print(covariance)\n [[ 4. 0. 0. 0. 0. 0. 0. 0.]\n [ 0. 4. 0. 0. 0. 0. 0. 0.]\n [ 0. 0. 4. 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 4. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.25 0. 0. 0.]\n [ 0. 0. 0. 0. 0. 0.25 0. 0.]\n [ 0. 0. 0. 0. 0. 0. 0.25 0.]\n [ 0. 0. 0. 0. 0. 0. 0. 0.25]]\n \"\"\"\n mean_pos = measurement\n mean_vel = np.zeros_like(mean_pos)\n mean = np.r_[mean_pos, mean_vel]\n\n std = [\n 2 * self._std_weight_position * measurement[2],\n 2 * self._std_weight_position * measurement[3],\n 2 * self._std_weight_position * measurement[2],\n 2 * self._std_weight_position * measurement[3],\n 10 * self._std_weight_velocity * measurement[2],\n 10 * self._std_weight_velocity * measurement[3],\n 10 * self._std_weight_velocity * measurement[2],\n 10 * self._std_weight_velocity * measurement[3],\n ]\n covariance = np.diag(np.square(std))\n return mean, covariance\n\n def predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step.\n\n Args:\n mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.\n covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean vector of the predicted state. 
Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix of the predicted state.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[2],\n self._std_weight_velocity * mean[3],\n self._std_weight_velocity * mean[2],\n self._std_weight_velocity * mean[3],\n ]\n motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))\n\n mean = np.dot(mean, self._motion_mat.T)\n covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov\n\n return mean, covariance\n\n def project(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Project state distribution to measurement space.\n\n Args:\n mean (np.ndarray): The state's mean vector (8 dimensional array).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n\n Returns:\n mean (np.ndarray): Projected mean of the given state estimate.\n covariance (np.ndarray): Projected covariance matrix of the given state estimate.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> projected_mean, projected_cov = kf.project(mean, covariance)\n \"\"\"\n std = [\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n ]\n innovation_cov = np.diag(np.square(std))\n\n mean = np.dot(self._update_mat, mean)\n covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))\n return mean, covariance + innovation_cov\n\n def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step (Vectorized version).\n\n Args:\n mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.\n covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).\n covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).\n\n Examples:\n >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors\n >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices\n >>> kf = KalmanFilterXYWH()\n >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[:, 2],\n self._std_weight_position * mean[:, 3],\n self._std_weight_position * mean[:, 2],\n self._std_weight_position * mean[:, 3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[:, 2],\n self._std_weight_velocity * mean[:, 3],\n self._std_weight_velocity * mean[:, 2],\n self._std_weight_velocity * mean[:, 3],\n ]\n sqr = np.square(np.r_[std_pos, std_vel]).T\n\n motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]\n motion_cov = np.asarray(motion_cov)\n\n mean = np.dot(mean, self._motion_mat.T)\n left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))\n covariance = np.dot(left, self._motion_mat.T) + motion_cov\n\n return mean, covariance\n\n def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: 
np.ndarray):\n \"\"\"\n Run Kalman filter correction step.\n\n Args:\n mean (np.ndarray): The predicted state's mean vector (8 dimensional).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n measurement (np.ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center\n position, w the width, and h the height of the bounding box.\n\n Returns:\n new_mean (np.ndarray): Measurement-corrected state mean.\n new_covariance (np.ndarray): Measurement-corrected state covariance.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> measurement = np.array([0.5, 0.5, 1.2, 1.2])\n >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)\n \"\"\"\n return super().update(mean, covariance, measurement)", "chunk_type": "class", "name": "KalmanFilterXYWH", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py", "start_line": 289, "end_line": 493, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": "A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.\n\nImplements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where\n(x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.\nThe object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct\nobservation of the state space (linear observation model).\n\nAttributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\nMethods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step in a vectorized manner.\n update: Run the Kalman filter correction step.\n\nExamples:\n Create a Kalman filter and initialize a track\n >>> kf = KalmanFilterXYWH()\n >>> measurement = np.array([100, 50, 20, 40])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)", "parameters": null, "return_type": null, "decorators": [], "complexity_score": null, "dependencies": [ "numpy", "scipy.linalg", "KalmanFilterXYAH" ], "chunk_id": "class_KalmanFilterXYWH_509215f6" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_46358698" }, { "content": "import scipy", "chunk_type": "import", "name": "scipy", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 12, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_scipy_72d84f80" }, { "content": "from scipy.spatial.distance import cdist", "chunk_type": "import", "name": "cdist", "file_path": 
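The vectorized `multi_predict` path is what the trackers use to advance every active track in one call. A sketch with three hypothetical tracks in `(x, y, w, h)` form:

```python
import numpy as np
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYWH

kf = KalmanFilterXYWH()
tracks = [kf.initiate(np.array([x, 50.0, 20.0, 40.0])) for x in (100.0, 200.0, 300.0)]
means = np.stack([m for m, _ in tracks])       # (3, 8) stacked track states
covs = np.stack([c for _, c in tracks])        # (3, 8, 8) stacked covariances
means, covs = kf.multi_predict(means, covs)    # one vectorized prediction for all tracks
print(means.shape, covs.shape)                 # (3, 8) (3, 8, 8)
```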
"ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cdist_40636338" }, { "content": "from ultralytics.utils.metrics import batch_probiou, bbox_ioa", "chunk_type": "import", "name": "batch_probiou, bbox_ioa", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_batch_probiou, bbox_ioa_cc293f11" }, { "content": "def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):\n \"\"\"\n Perform linear assignment using either the scipy or lap.lapjv method.\n\n Args:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n thresh (float): Threshold for considering an assignment valid.\n use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.\n\n Returns:\n matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.\n unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).\n unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).\n\n Examples:\n >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n >>> thresh = 5.0\n >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)\n \"\"\"\n if cost_matrix.size == 0:\n return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))\n\n if use_lap:\n # Use lap.lapjv\n # https://github.com/gatagat/lap\n _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)\n matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]\n unmatched_a = np.where(x < 0)[0]\n unmatched_b = np.where(y < 0)[0]\n else:\n # Use scipy.optimize.linear_sum_assignment\n # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html\n x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y\n matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])\n if len(matches) == 0:\n unmatched_a = list(np.arange(cost_matrix.shape[0]))\n unmatched_b = list(np.arange(cost_matrix.shape[1]))\n else:\n unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))\n unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))\n\n return matches, unmatched_a, unmatched_b", "chunk_type": "function", "name": "linear_assignment", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 20, "end_line": 61, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Perform linear assignment using either the scipy or lap.lapjv method.\n\nArgs:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n thresh (float): Threshold for considering an assignment valid.\n use_lap (bool): Use lap.lapjv for the assignment. 
If False, scipy.optimize.linear_sum_assignment is used.\n\nReturns:\n matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.\n unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).\n unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).\n\nExamples:\n >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n >>> thresh = 5.0\n >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)", "parameters": [ "cost_matrix: np.ndarray", "thresh: float", "use_lap: bool" ], "return_type": null, "decorators": [], "complexity_score": 6, "dependencies": [ "numpy", "scipy", "scipy.spatial.distance.cdist", "ultralytics.utils.metrics.batch_probiou", "ultralytics.utils.metrics.bbox_ioa", "lap", "ultralytics.utils.checks.check_requirements", "lap" ], "chunk_id": "function_linear_assignment_c451a178" }, { "content": "def iou_distance(atracks: list, btracks: list) -> np.ndarray:\n \"\"\"\n Compute cost based on Intersection over Union (IoU) between tracks.\n\n Args:\n atracks (List[STrack] | List[np.ndarray]): List of tracks 'a' or bounding boxes.\n btracks (List[STrack] | List[np.ndarray]): List of tracks 'b' or bounding boxes.\n\n Returns:\n (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)).\n\n Examples:\n Compute IoU distance between two sets of tracks\n >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]\n >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]\n >>> cost_matrix = iou_distance(atracks, btracks)\n \"\"\"\n if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):\n atlbrs = atracks\n btlbrs = btracks\n else:\n atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]\n btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]\n\n ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)\n if len(atlbrs) and len(btlbrs):\n if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:\n ious = batch_probiou(\n np.ascontiguousarray(atlbrs, dtype=np.float32),\n np.ascontiguousarray(btlbrs, dtype=np.float32),\n ).numpy()\n else:\n ious = bbox_ioa(\n np.ascontiguousarray(atlbrs, dtype=np.float32),\n np.ascontiguousarray(btlbrs, dtype=np.float32),\n iou=True,\n )\n return 1 - ious # cost matrix", "chunk_type": "function", "name": "iou_distance", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 64, "end_line": 101, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Compute cost based on Intersection over Union (IoU) between tracks.\n\nArgs:\n atracks (List[STrack] | List[np.ndarray]): List of tracks 'a' or bounding boxes.\n btracks (List[STrack] | List[np.ndarray]): List of tracks 'b' or bounding boxes.\n\nReturns:\n (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)).\n\nExamples:\n Compute IoU distance between two sets of tracks\n >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]\n >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]\n >>> cost_matrix = iou_distance(atracks, btracks)", "parameters": [ "atracks: list", "btracks: list" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 6, "dependencies": [ "numpy", "scipy", "scipy.spatial.distance.cdist", "ultralytics.utils.metrics.batch_probiou", 
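A minimal sketch of linear_assignment using the scipy backend (use_lap=False, so the optional lap package is not needed); the 2x2 cost matrix and threshold are arbitrary example values.

```python
import numpy as np
from ultralytics.trackers.utils.matching import linear_assignment

cost = np.array([[0.2, 0.9],
                 [0.8, 0.3]], dtype=np.float32)
# Only pairs whose cost is <= thresh are kept as matches
matches, unmatched_a, unmatched_b = linear_assignment(cost, thresh=0.5, use_lap=False)
print(matches)                    # [[0 0], [1 1]]
print(unmatched_a, unmatched_b)   # [] []
```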
"ultralytics.utils.metrics.bbox_ioa", "lap", "ultralytics.utils.checks.check_requirements", "lap" ], "chunk_id": "function_iou_distance_046b5443" }, { "content": "def embedding_distance(tracks: list, detections: list, metric: str = \"cosine\") -> np.ndarray:\n \"\"\"\n Compute distance between tracks and detections based on embeddings.\n\n Args:\n tracks (List[STrack]): List of tracks, where each track contains embedding features.\n detections (List[BaseTrack]): List of detections, where each detection contains embedding features.\n metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.\n\n Returns:\n (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks\n and M is the number of detections.\n\n Examples:\n Compute the embedding distance between tracks and detections using cosine metric\n >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features\n >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features\n >>> cost_matrix = embedding_distance(tracks, detections, metric=\"cosine\")\n \"\"\"\n cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)\n if cost_matrix.size == 0:\n return cost_matrix\n det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)\n # for i, track in enumerate(tracks):\n # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))\n track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)\n cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features\n return cost_matrix", "chunk_type": "function", "name": "embedding_distance", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 104, "end_line": 131, "start_col": 0, "end_col": 22, "parent_name": null, "docstring": "Compute distance between tracks and detections based on embeddings.\n\nArgs:\n tracks (List[STrack]): List of tracks, where each track contains embedding features.\n detections (List[BaseTrack]): List of detections, where each detection contains embedding features.\n metric (str): Metric for distance computation. 
Supported metrics include 'cosine', 'euclidean', etc.\n\nReturns:\n (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks\n and M is the number of detections.\n\nExamples:\n Compute the embedding distance between tracks and detections using cosine metric\n >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features\n >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features\n >>> cost_matrix = embedding_distance(tracks, detections, metric=\"cosine\")", "parameters": [ "tracks: list", "detections: list", "metric: str" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 4, "dependencies": [ "numpy", "scipy", "scipy.spatial.distance.cdist", "ultralytics.utils.metrics.batch_probiou", "ultralytics.utils.metrics.bbox_ioa", "lap", "ultralytics.utils.checks.check_requirements", "lap" ], "chunk_id": "function_embedding_distance_5fb9ec69" }, { "content": "def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:\n \"\"\"\n Fuse cost matrix with detection scores to produce a single similarity matrix.\n\n Args:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n detections (List[BaseTrack]): List of detections, each containing a score attribute.\n\n Returns:\n (np.ndarray): Fused similarity matrix with shape (N, M).\n\n Examples:\n Fuse a cost matrix with detection scores\n >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections\n >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]\n >>> fused_matrix = fuse_score(cost_matrix, detections)\n \"\"\"\n if cost_matrix.size == 0:\n return cost_matrix\n iou_sim = 1 - cost_matrix\n det_scores = np.array([det.score for det in detections])\n det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)\n fuse_sim = iou_sim * det_scores\n return 1 - fuse_sim # fuse_cost", "chunk_type": "function", "name": "fuse_score", "file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py", "start_line": 134, "end_line": 157, "start_col": 0, "end_col": 23, "parent_name": null, "docstring": "Fuse cost matrix with detection scores to produce a single similarity matrix.\n\nArgs:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n detections (List[BaseTrack]): List of detections, each containing a score attribute.\n\nReturns:\n (np.ndarray): Fused similarity matrix with shape (N, M).\n\nExamples:\n Fuse a cost matrix with detection scores\n >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections\n >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]\n >>> fused_matrix = fuse_score(cost_matrix, detections)", "parameters": [ "cost_matrix: np.ndarray", "detections: list" ], "return_type": "np.ndarray", "decorators": [], "complexity_score": 3, "dependencies": [ "numpy", "scipy", "scipy.spatial.distance.cdist", "ultralytics.utils.metrics.batch_probiou", "ultralytics.utils.metrics.bbox_ioa", "lap", "ultralytics.utils.checks.check_requirements", "lap" ], "chunk_id": "function_fuse_score_2d29eb18" }, { "content": "from collections import defaultdict", "chunk_type": "import", "name": "defaultdict", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 35, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, 
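embedding_distance and fuse_score only read a few attributes from the track/detection objects, so lightweight stand-ins are enough for a sketch; the SimpleNamespace objects, feature dimension, and scores below are hypothetical.

```python
import numpy as np
from types import SimpleNamespace
from ultralytics.trackers.utils.matching import embedding_distance, fuse_score

# Stand-ins exposing only the attributes these helpers read
tracks = [SimpleNamespace(smooth_feat=np.random.rand(128).astype(np.float32)) for _ in range(2)]
dets = [SimpleNamespace(curr_feat=np.random.rand(128).astype(np.float32), score=s) for s in (0.9, 0.6, 0.4)]

emb_cost = embedding_distance(tracks, dets, metric="cosine")   # (2, 3) appearance cost
iou_cost = np.random.rand(2, 3).astype(np.float32)             # placeholder IoU cost matrix
fused = fuse_score(iou_cost, dets)                             # blends IoU similarity with detection scores
print(emb_cost.shape, fused.shape)                             # (2, 3) (2, 3)
```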
"complexity_score": null, "dependencies": null, "chunk_id": "import_defaultdict_32412c51" }, { "content": "from copy import deepcopy", "chunk_type": "import", "name": "deepcopy", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 25, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_deepcopy_fb9df1a7" }, { "content": "def on_pretrain_routine_start(trainer):\n \"\"\"Called before the pretraining routine starts.\"\"\"\n pass", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 10, "end_line": 12, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called before the pretraining routine starts.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_pretrain_routine_start_f6e22321" }, { "content": "def on_pretrain_routine_end(trainer):\n \"\"\"Called after the pretraining routine ends.\"\"\"\n pass", "chunk_type": "function", "name": "on_pretrain_routine_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 15, "end_line": 17, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called after the pretraining routine ends.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_pretrain_routine_end_82d57cc4" }, { "content": "def on_train_start(trainer):\n \"\"\"Called when the training starts.\"\"\"\n pass", "chunk_type": "function", "name": "on_train_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 20, "end_line": 22, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the training starts.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_train_start_678ff1f2" }, { "content": "def on_train_epoch_start(trainer):\n \"\"\"Called at the start of each training epoch.\"\"\"\n pass", "chunk_type": "function", "name": "on_train_epoch_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 25, "end_line": 27, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the start of each training epoch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": 
"function_on_train_epoch_start_c5ea4d1e" }, { "content": "def on_train_batch_start(trainer):\n \"\"\"Called at the start of each training batch.\"\"\"\n pass", "chunk_type": "function", "name": "on_train_batch_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 30, "end_line": 32, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the start of each training batch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_train_batch_start_1b5f49c6" }, { "content": "def optimizer_step(trainer):\n \"\"\"Called when the optimizer takes a step.\"\"\"\n pass", "chunk_type": "function", "name": "optimizer_step", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 35, "end_line": 37, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the optimizer takes a step.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_optimizer_step_26375cc8" }, { "content": "def on_before_zero_grad(trainer):\n \"\"\"Called before the gradients are set to zero.\"\"\"\n pass", "chunk_type": "function", "name": "on_before_zero_grad", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 40, "end_line": 42, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called before the gradients are set to zero.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_before_zero_grad_460f4a5b" }, { "content": "def on_train_batch_end(trainer):\n \"\"\"Called at the end of each training batch.\"\"\"\n pass", "chunk_type": "function", "name": "on_train_batch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 45, "end_line": 47, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the end of each training batch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_train_batch_end_00f24bdc" }, { "content": "def on_train_epoch_end(trainer):\n \"\"\"Called at the end of each training epoch.\"\"\"\n pass", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 50, "end_line": 52, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the end of each training epoch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], 
"complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_train_epoch_end_604d33e9" }, { "content": "def on_fit_epoch_end(trainer):\n \"\"\"Called at the end of each fit epoch (train + val).\"\"\"\n pass", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 55, "end_line": 57, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the end of each fit epoch (train + val).", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_fit_epoch_end_0025f6b1" }, { "content": "def on_model_save(trainer):\n \"\"\"Called when the model is saved.\"\"\"\n pass", "chunk_type": "function", "name": "on_model_save", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 60, "end_line": 62, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the model is saved.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_model_save_14d6ee39" }, { "content": "def on_train_end(trainer):\n \"\"\"Called when the training ends.\"\"\"\n pass", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 65, "end_line": 67, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the training ends.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_train_end_019e91c8" }, { "content": "def on_params_update(trainer):\n \"\"\"Called when the model parameters are updated.\"\"\"\n pass", "chunk_type": "function", "name": "on_params_update", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 70, "end_line": 72, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the model parameters are updated.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_params_update_aae0e6c4" }, { "content": "def teardown(trainer):\n \"\"\"Called during the teardown of the training process.\"\"\"\n pass", "chunk_type": "function", "name": "teardown", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", 
"start_line": 75, "end_line": 77, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called during the teardown of the training process.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_teardown_2d541efb" }, { "content": "def on_val_start(validator):\n \"\"\"Called when the validation starts.\"\"\"\n pass", "chunk_type": "function", "name": "on_val_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 83, "end_line": 85, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the validation starts.", "parameters": [ "validator" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_val_start_8e1a914d" }, { "content": "def on_val_batch_start(validator):\n \"\"\"Called at the start of each validation batch.\"\"\"\n pass", "chunk_type": "function", "name": "on_val_batch_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 88, "end_line": 90, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the start of each validation batch.", "parameters": [ "validator" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_val_batch_start_5dd04932" }, { "content": "def on_val_batch_end(validator):\n \"\"\"Called at the end of each validation batch.\"\"\"\n pass", "chunk_type": "function", "name": "on_val_batch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 93, "end_line": 95, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the end of each validation batch.", "parameters": [ "validator" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_val_batch_end_9c0eee9c" }, { "content": "def on_val_end(validator):\n \"\"\"Called when the validation ends.\"\"\"\n pass", "chunk_type": "function", "name": "on_val_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 98, "end_line": 100, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the validation ends.", "parameters": [ "validator" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_val_end_bd55536d" }, { "content": 
"def on_predict_start(predictor):\n \"\"\"Called when the prediction starts.\"\"\"\n pass", "chunk_type": "function", "name": "on_predict_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 106, "end_line": 108, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the prediction starts.", "parameters": [ "predictor" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_predict_start_7f9e42dc" }, { "content": "def on_predict_batch_start(predictor):\n \"\"\"Called at the start of each prediction batch.\"\"\"\n pass", "chunk_type": "function", "name": "on_predict_batch_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 111, "end_line": 113, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the start of each prediction batch.", "parameters": [ "predictor" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_predict_batch_start_f9aaf040" }, { "content": "def on_predict_batch_end(predictor):\n \"\"\"Called at the end of each prediction batch.\"\"\"\n pass", "chunk_type": "function", "name": "on_predict_batch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 116, "end_line": 118, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called at the end of each prediction batch.", "parameters": [ "predictor" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_predict_batch_end_0ae684b0" }, { "content": "def on_predict_postprocess_end(predictor):\n \"\"\"Called after the post-processing of the prediction ends.\"\"\"\n pass", "chunk_type": "function", "name": "on_predict_postprocess_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 121, "end_line": 123, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called after the post-processing of the prediction ends.", "parameters": [ "predictor" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_predict_postprocess_end_357c0141" }, { "content": "def on_predict_end(predictor):\n \"\"\"Called when the prediction ends.\"\"\"\n pass", "chunk_type": "function", "name": "on_predict_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 126, "end_line": 128, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the prediction ends.", "parameters": [ "predictor" ], "return_type": null, 
"decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_predict_end_9942070b" }, { "content": "def on_export_start(exporter):\n \"\"\"Called when the model export starts.\"\"\"\n pass", "chunk_type": "function", "name": "on_export_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 134, "end_line": 136, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the model export starts.", "parameters": [ "exporter" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_export_start_d6541e30" }, { "content": "def on_export_end(exporter):\n \"\"\"Called when the model export ends.\"\"\"\n pass", "chunk_type": "function", "name": "on_export_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 139, "end_line": 141, "start_col": 0, "end_col": 8, "parent_name": null, "docstring": "Called when the model export ends.", "parameters": [ "exporter" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_on_export_end_6fd84381" }, { "content": "default_callbacks = {\n # Run in trainer\n \"on_pretrain_routine_start\": [on_pretrain_routine_start],\n \"on_pretrain_routine_end\": [on_pretrain_routine_end],\n \"on_train_start\": [on_train_start],\n \"on_train_epoch_start\": [on_train_epoch_start],\n \"on_train_batch_start\": [on_train_batch_start],\n \"optimizer_step\": [optimizer_step],\n \"on_before_zero_grad\": [on_before_zero_grad],\n \"on_train_batch_end\": [on_train_batch_end],\n \"on_train_epoch_end\": [on_train_epoch_end],\n \"on_fit_epoch_end\": [on_fit_epoch_end], # fit = train + val\n \"on_model_save\": [on_model_save],\n \"on_train_end\": [on_train_end],\n \"on_params_update\": [on_params_update],\n \"teardown\": [teardown],\n # Run in validator\n \"on_val_start\": [on_val_start],\n \"on_val_batch_start\": [on_val_batch_start],\n \"on_val_batch_end\": [on_val_batch_end],\n \"on_val_end\": [on_val_end],\n # Run in predictor\n \"on_predict_start\": [on_predict_start],\n \"on_predict_batch_start\": [on_predict_batch_start],\n \"on_predict_postprocess_end\": [on_predict_postprocess_end],\n \"on_predict_batch_end\": [on_predict_batch_end],\n \"on_predict_end\": [on_predict_end],\n # Run in exporter\n \"on_export_start\": [on_export_start],\n \"on_export_end\": [on_export_end],\n}", "chunk_type": "variable", "name": "default_callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 144, "end_line": 174, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_default_callbacks_53b1012c" }, { "content": "def get_default_callbacks():\n \"\"\"\n Get the 
default callbacks for Ultralytics training, validation, prediction, and export processes.\n\n Returns:\n (dict): Dictionary of default callbacks for various training events. Each key represents an event during the\n training process, and the corresponding value is a list of callback functions executed when that event\n occurs.\n\n Examples:\n >>> callbacks = get_default_callbacks()\n >>> print(list(callbacks.keys())) # show all available callback events\n ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]\n \"\"\"\n return defaultdict(list, deepcopy(default_callbacks))", "chunk_type": "function", "name": "get_default_callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 177, "end_line": 191, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": "Get the default callbacks for Ultralytics training, validation, prediction, and export processes.\n\nReturns:\n (dict): Dictionary of default callbacks for various training events. Each key represents an event during the\n training process, and the corresponding value is a list of callback functions executed when that event\n occurs.\n\nExamples:\n >>> callbacks = get_default_callbacks()\n >>> print(list(callbacks.keys())) # show all available callback events\n ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]", "parameters": [], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_get_default_callbacks_eea37bf9" }, { "content": "def add_integration_callbacks(instance):\n \"\"\"\n Add integration callbacks to the instance's callbacks dictionary.\n\n This function loads and adds various integration callbacks to the provided instance. The specific callbacks added\n depend on the type of instance provided. 
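The no-op hooks above act as a dispatch schema: get_default_callbacks() returns a defaultdict(list) keyed by event name, and extra hooks with the same single-argument signature can simply be appended. The trainer object in this sketch is a hypothetical stand-in, not the real Trainer.

```python
from types import SimpleNamespace
from ultralytics.utils.callbacks.base import get_default_callbacks

def print_epoch(trainer):                        # custom hook, same signature as the built-ins
    print(f"finished epoch {trainer.epoch}")

callbacks = get_default_callbacks()              # defaultdict(list) keyed by event name
callbacks["on_train_epoch_end"].append(print_epoch)

trainer = SimpleNamespace(epoch=0)               # hypothetical stand-in for a Trainer
for fn in callbacks["on_train_epoch_end"]:       # fire the event as the trainer would
    fn(trainer)
```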
All instances receive HUB callbacks, while Trainer instances also receive\n additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,\n and Weights & Biases.\n\n Args:\n instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.\n The type of instance determines which callbacks are loaded.\n\n Examples:\n >>> from ultralytics.engine.trainer import BaseTrainer\n >>> trainer = BaseTrainer()\n >>> add_integration_callbacks(trainer)\n \"\"\"\n # Load HUB callbacks\n from .hub import callbacks as hub_cb\n\n callbacks_list = [hub_cb]\n\n # Load training callbacks\n if \"Trainer\" in instance.__class__.__name__:\n from .clearml import callbacks as clear_cb\n from .comet import callbacks as comet_cb\n from .dvc import callbacks as dvc_cb\n from .mlflow import callbacks as mlflow_cb\n from .neptune import callbacks as neptune_cb\n from .raytune import callbacks as tune_cb\n from .tensorboard import callbacks as tb_cb\n from .wb import callbacks as wb_cb\n\n callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])\n\n # Add the callbacks to the callbacks dictionary\n for callbacks in callbacks_list:\n for k, v in callbacks.items():\n if v not in instance.callbacks[k]:\n instance.callbacks[k].append(v)", "chunk_type": "function", "name": "add_integration_callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py", "start_line": 194, "end_line": 234, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Add integration callbacks to the instance's callbacks dictionary.\n\nThis function loads and adds various integration callbacks to the provided instance. The specific callbacks added\ndepend on the type of instance provided. 
All instances receive HUB callbacks, while Trainer instances also receive\nadditional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,\nand Weights & Biases.\n\nArgs:\n instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.\n The type of instance determines which callbacks are loaded.\n\nExamples:\n >>> from ultralytics.engine.trainer import BaseTrainer\n >>> trainer = BaseTrainer()\n >>> add_integration_callbacks(trainer)", "parameters": [ "instance" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "collections.defaultdict", "copy.deepcopy", "hub.callbacks", "clearml.callbacks", "comet.callbacks", "dvc.callbacks", "mlflow.callbacks", "neptune.callbacks", "raytune.callbacks", "tensorboard.callbacks", "wb.callbacks" ], "chunk_id": "function_add_integration_callbacks_5049c2f7" }, { "content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING", "chunk_type": "import", "name": "LOGGER, SETTINGS, TESTS_RUNNING", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING_f460cad7" }, { "content": "def _log_debug_samples(files, title: str = \"Debug Samples\") -> None:\n \"\"\"\n Log files (images) as debug samples in the ClearML task.\n\n Args:\n files (List[Path]): A list of file paths in PosixPath format.\n title (str): A title that groups together images with the same values.\n \"\"\"\n import re\n\n if task := Task.current_task():\n for f in files:\n if f.exists():\n it = re.search(r\"_batch(\\d+)\", f.name)\n iteration = int(it.groups()[0]) if it else 0\n task.get_logger().report_image(\n title=title, series=f.name.replace(it.group(), \"\"), local_path=str(f), iteration=iteration\n )", "chunk_type": "function", "name": "_log_debug_samples", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 17, "end_line": 34, "start_col": 0, "end_col": 17, "parent_name": null, "docstring": "Log files (images) as debug samples in the ClearML task.\n\nArgs:\n files (List[Path]): A list of file paths in PosixPath format.\n title (str): A title that groups together images with the same values.", "parameters": [ "files", "title: str" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", "clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_debug_samples_f955ae87" }, { "content": "def _log_plot(title: str, plot_path: str) -> None:\n \"\"\"\n Log an image as a plot in the plot section of ClearML.\n\n Args:\n title (str): The title of the plot.\n plot_path (str): The path to the saved image file.\n \"\"\"\n import matplotlib.image as mpimg\n import matplotlib.pyplot as plt\n\n img = mpimg.imread(plot_path)\n fig = plt.figure()\n ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=\"auto\", xticks=[], yticks=[]) # no ticks\n ax.imshow(img)\n\n Task.current_task().get_logger().report_matplotlib_figure(\n title=title, 
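A minimal sketch of add_integration_callbacks, assuming the ultralytics package is installed; which integrations are actually appended depends on the optional packages available (ClearML, Comet, MLflow, ...) and on the class name containing "Trainer". The Trainer class here is a bare stand-in, not the real BaseTrainer.

```python
from ultralytics.utils.callbacks.base import add_integration_callbacks, get_default_callbacks

class Trainer:                                   # stand-in; the name must contain "Trainer"
    def __init__(self):
        self.callbacks = get_default_callbacks()

trainer = Trainer()
add_integration_callbacks(trainer)               # appends HUB + any available integration hooks
print(len(trainer.callbacks["on_train_start"]))  # grows when integrations register hooks for this event
```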
series=\"\", figure=fig, report_interactive=False\n )", "chunk_type": "function", "name": "_log_plot", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 37, "end_line": 55, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Log an image as a plot in the plot section of ClearML.\n\nArgs:\n title (str): The title of the plot.\n plot_path (str): The path to the saved image file.", "parameters": [ "title: str", "plot_path: str" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", "clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_plot_0104fe1a" }, { "content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize and connect ClearML task at the start of pretraining routine.\"\"\"\n try:\n if task := Task.current_task():\n # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!\n # We are logging these plots and model files manually in the integration\n from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO\n from clearml.binding.matplotlib_bind import PatchedMatplotlib\n\n PatchPyTorchModelIO.update_current_task(None)\n PatchedMatplotlib.update_current_task(None)\n else:\n task = Task.init(\n project_name=trainer.args.project or \"Ultralytics\",\n task_name=trainer.args.name,\n tags=[\"Ultralytics\"],\n output_uri=True,\n reuse_last_task_id=False,\n auto_connect_frameworks={\"pytorch\": False, \"matplotlib\": False},\n )\n LOGGER.warning(\n \"ClearML Initialized a new task. If you want to run remotely, \"\n \"please add clearml-init and connect your arguments before initializing YOLO.\"\n )\n task.connect(vars(trainer.args), name=\"General\")\n except Exception as e:\n LOGGER.warning(f\"ClearML installed but not initialized correctly, not logging this run. 
{e}\")", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 58, "end_line": 84, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": "Initialize and connect ClearML task at the start of pretraining routine.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", "clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_start_5ce93a7c" }, { "content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log debug samples for the first epoch and report current training progress.\"\"\"\n if task := Task.current_task():\n # Log debug samples for first epoch only\n if trainer.epoch == 1:\n _log_debug_samples(sorted(trainer.save_dir.glob(\"train_batch*.jpg\")), \"Mosaic\")\n # Report the current training progress\n for k, v in trainer.label_loss_items(trainer.tloss, prefix=\"train\").items():\n task.get_logger().report_scalar(\"train\", k, v, iteration=trainer.epoch)\n for k, v in trainer.lr.items():\n task.get_logger().report_scalar(\"lr\", k, v, iteration=trainer.epoch)", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 87, "end_line": 97, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": "Log debug samples for the first epoch and report current training progress.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 5, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", "clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_epoch_end_152ca3b7" }, { "content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"Report model information and metrics to logger at the end of an epoch.\"\"\"\n if task := Task.current_task():\n # Report epoch time and validation metrics\n task.get_logger().report_scalar(\n title=\"Epoch Time\", series=\"Epoch Time\", value=trainer.epoch_time, iteration=trainer.epoch\n )\n for k, v in trainer.metrics.items():\n task.get_logger().report_scalar(\"val\", k, v, iteration=trainer.epoch)\n if trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n for k, v in model_info_for_loggers(trainer).items():\n task.get_logger().report_single_value(k, v)", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 100, "end_line": 113, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "Report model information and metrics to logger at the end of an epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 5, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", 
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_fit_epoch_end_3d64d955" }, { "content": "def on_val_end(validator) -> None:\n \"\"\"Log validation results including labels and predictions.\"\"\"\n if Task.current_task():\n # Log validation labels and predictions\n _log_debug_samples(sorted(validator.save_dir.glob(\"val*.jpg\")), \"Validation\")", "chunk_type": "function", "name": "on_val_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 116, "end_line": 120, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": "Log validation results including labels and predictions.", "parameters": [ "validator" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", "clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_val_end_9a3d3651" }, { "content": "def on_train_end(trainer) -> None:\n \"\"\"Log final model and training results on training completion.\"\"\"\n if task := Task.current_task():\n # Log final results, confusion matrix and PR plots\n files = [\n \"results.png\",\n \"confusion_matrix.png\",\n \"confusion_matrix_normalized.png\",\n *(f\"{x}_curve.png\" for x in (\"F1\", \"PR\", \"P\", \"R\")),\n ]\n files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter existing files\n for f in files:\n _log_plot(title=f.stem, plot_path=f)\n # Report final metrics\n for k, v in trainer.validator.metrics.results_dict.items():\n task.get_logger().report_single_value(k, v)\n # Log the final model\n task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 123, "end_line": 140, "start_col": 0, "end_col": 116, "parent_name": null, "docstring": "Log final model and training results on training completion.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "clearml", "clearml.Task", "re", "matplotlib.image", "matplotlib.pyplot", "clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO", "clearml.binding.matplotlib_bind.PatchedMatplotlib", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_end_9539f420" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_val_end\": on_val_end,\n \"on_train_end\": on_train_end,\n }\n if clearml\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py", "start_line": 143, "end_line": 153, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_03e51358" }, { "content": "from 
collections.abc import Callable", "chunk_type": "import", "name": "Callable", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 36, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Callable_c4d6c13a" }, { "content": "from types import SimpleNamespace", "chunk_type": "import", "name": "SimpleNamespace", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 33, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SimpleNamespace_dda040c9" }, { "content": "from typing import Any, List, Optional", "chunk_type": "import", "name": "Any, List, Optional", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Any, List, Optional_2b5fd40b" }, { "content": "import cv2", "chunk_type": "import", "name": "cv2", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 10, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_cv2_93c76c4c" }, { "content": "import numpy as np", "chunk_type": "import", "name": "numpy", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 8, "end_line": 8, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_numpy_0d19a4cc" }, { "content": "from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops", "chunk_type": "import", "name": "LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 10, "end_line": 10, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops_07aeb2f7" }, { "content": "from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics", "chunk_type": "import", "name": "ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 11, "end_line": 11, "start_col": 0, "end_col": 106, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics_11aefd35" }, { "content": "def _get_comet_mode() -> str:\n \"\"\"Return the Comet mode from environment variables, defaulting to 'online'.\"\"\"\n comet_mode = os.getenv(\"COMET_MODE\")\n if comet_mode is not None:\n LOGGER.warning(\n \"The COMET_MODE environment variable is deprecated. \"\n \"Please use COMET_START_ONLINE to set the Comet experiment mode. 
\"\n \"To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. \"\n \"If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created.\"\n )\n return comet_mode\n\n return \"online\"", "chunk_type": "function", "name": "_get_comet_mode", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 39, "end_line": 51, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Return the Comet mode from environment variables, defaulting to 'online'.", "parameters": [], "return_type": "str", "decorators": [], "complexity_score": 2, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__get_comet_mode_98d487f2" }, { "content": "def _get_comet_model_name() -> str:\n \"\"\"Return the Comet model name from environment variable or default to 'Ultralytics'.\"\"\"\n return os.getenv(\"COMET_MODEL_NAME\", \"Ultralytics\")", "chunk_type": "function", "name": "_get_comet_model_name", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 54, "end_line": 56, "start_col": 0, "end_col": 55, "parent_name": null, "docstring": "Return the Comet model name from environment variable or default to 'Ultralytics'.", "parameters": [], "return_type": "str", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__get_comet_model_name_1b12a748" }, { "content": "def _get_eval_batch_logging_interval() -> int:\n \"\"\"Get the evaluation batch logging interval from environment variable or use default value 1.\"\"\"\n return int(os.getenv(\"COMET_EVAL_BATCH_LOGGING_INTERVAL\", 1))", "chunk_type": "function", "name": "_get_eval_batch_logging_interval", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 59, "end_line": 61, "start_col": 0, "end_col": 65, "parent_name": null, "docstring": "Get the evaluation batch logging interval from environment variable or use default value 1.", "parameters": [], "return_type": "int", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", 
"ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__get_eval_batch_logging_interval_67f6c443" }, { "content": "def _get_max_image_predictions_to_log() -> int:\n \"\"\"Get the maximum number of image predictions to log from environment variables.\"\"\"\n return int(os.getenv(\"COMET_MAX_IMAGE_PREDICTIONS\", 100))", "chunk_type": "function", "name": "_get_max_image_predictions_to_log", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 64, "end_line": 66, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": "Get the maximum number of image predictions to log from environment variables.", "parameters": [], "return_type": "int", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__get_max_image_predictions_to_log_ded60425" }, { "content": "def _scale_confidence_score(score: float) -> float:\n \"\"\"Scale the confidence score by a factor specified in environment variable.\"\"\"\n scale = float(os.getenv(\"COMET_MAX_CONFIDENCE_SCORE\", 100.0))\n return score * scale", "chunk_type": "function", "name": "_scale_confidence_score", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 69, "end_line": 72, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": "Scale the confidence score by a factor specified in environment variable.", "parameters": [ "score: float" ], "return_type": "float", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__scale_confidence_score_bee04db0" }, { "content": "def _should_log_confusion_matrix() -> bool:\n \"\"\"Determine if the confusion matrix should be logged based on environment variable settings.\"\"\"\n return os.getenv(\"COMET_EVAL_LOG_CONFUSION_MATRIX\", \"false\").lower() == \"true\"", "chunk_type": "function", "name": "_should_log_confusion_matrix", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 75, "end_line": 77, "start_col": 0, "end_col": 82, "parent_name": null, "docstring": "Determine if the confusion matrix should be logged based on 
environment variable settings.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__should_log_confusion_matrix_88937404" }, { "content": "def _should_log_image_predictions() -> bool:\n \"\"\"Determine whether to log image predictions based on environment variable.\"\"\"\n return os.getenv(\"COMET_EVAL_LOG_IMAGE_PREDICTIONS\", \"true\").lower() == \"true\"", "chunk_type": "function", "name": "_should_log_image_predictions", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 80, "end_line": 82, "start_col": 0, "end_col": 82, "parent_name": null, "docstring": "Determine whether to log image predictions based on environment variable.", "parameters": [], "return_type": "bool", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__should_log_image_predictions_2ec21984" }, { "content": "def _resume_or_create_experiment(args: SimpleNamespace) -> None:\n \"\"\"\n Resume CometML experiment or create a new experiment based on args.\n\n Ensures that the experiment object is only created in a single process during distributed training.\n\n Args:\n args (SimpleNamespace): Training arguments containing project configuration and other parameters.\n \"\"\"\n if RANK not in {-1, 0}:\n return\n\n # Set environment variable (if not set by the user) to configure the Comet experiment's online mode under the hood.\n # IF COMET_START_ONLINE is set by the user it will override COMET_MODE value.\n if os.getenv(\"COMET_START_ONLINE\") is None:\n comet_mode = _get_comet_mode()\n os.environ[\"COMET_START_ONLINE\"] = \"1\" if comet_mode != \"offline\" else \"0\"\n\n try:\n _project_name = os.getenv(\"COMET_PROJECT_NAME\", args.project)\n experiment = comet_ml.start(project_name=_project_name)\n experiment.log_parameters(vars(args))\n experiment.log_others(\n {\n \"eval_batch_logging_interval\": _get_eval_batch_logging_interval(),\n \"log_confusion_matrix_on_eval\": _should_log_confusion_matrix(),\n \"log_image_predictions\": _should_log_image_predictions(),\n \"max_image_predictions\": _get_max_image_predictions_to_log(),\n }\n )\n experiment.log_other(\"Created from\", \"ultralytics\")\n\n except Exception as e:\n LOGGER.warning(f\"Comet installed but not initialized correctly, not logging this run. 
{e}\")", "chunk_type": "function", "name": "_resume_or_create_experiment", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 85, "end_line": 118, "start_col": 0, "end_col": 99, "parent_name": null, "docstring": "Resume CometML experiment or create a new experiment based on args.\n\nEnsures that the experiment object is only created in a single process during distributed training.\n\nArgs:\n args (SimpleNamespace): Training arguments containing project configuration and other parameters.", "parameters": [ "args: SimpleNamespace" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__resume_or_create_experiment_bf13f87e" }, { "content": "def _fetch_trainer_metadata(trainer) -> dict:\n \"\"\"\n Return metadata for YOLO training including epoch and asset saving status.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.\n\n Returns:\n (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.\n \"\"\"\n curr_epoch = trainer.epoch + 1\n\n train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size\n curr_step = curr_epoch * train_num_steps_per_epoch\n final_epoch = curr_epoch == trainer.epochs\n\n save = trainer.args.save\n save_period = trainer.args.save_period\n save_interval = curr_epoch % save_period == 0\n save_assets = save and save_period > 0 and save_interval and not final_epoch\n\n return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)", "chunk_type": "function", "name": "_fetch_trainer_metadata", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 121, "end_line": 142, "start_col": 0, "end_col": 109, "parent_name": null, "docstring": "Return metadata for YOLO training including epoch and asset saving status.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.\n\nReturns:\n (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.", "parameters": [ "trainer" ], "return_type": "dict", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__fetch_trainer_metadata_010b2d25" }, { "content": "def 
_scale_bounding_box_to_original_image_shape(\n box, resized_image_shape, original_image_shape, ratio_pad\n) -> List[float]:\n \"\"\"\n Scale bounding box from resized image coordinates to original image coordinates.\n\n YOLO resizes images during training and the label values are normalized based on this resized shape.\n This function rescales the bounding box labels to the original image shape.\n\n Args:\n box (torch.Tensor): Bounding box in normalized xywh format.\n resized_image_shape (tuple): Shape of the resized image (height, width).\n original_image_shape (tuple): Shape of the original image (height, width).\n ratio_pad (tuple): Ratio and padding information for scaling.\n\n Returns:\n (List[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.\n \"\"\"\n resized_image_height, resized_image_width = resized_image_shape\n\n # Convert normalized xywh format predictions to xyxy in resized scale format\n box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)\n # Scale box predictions from resized image scale back to original image scale\n box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)\n # Convert bounding box format from xyxy to xywh for Comet logging\n box = ops.xyxy2xywh(box)\n # Adjust xy center to correspond top-left corner\n box[:2] -= box[2:] / 2\n box = box.tolist()\n\n return box", "chunk_type": "function", "name": "_scale_bounding_box_to_original_image_shape", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 145, "end_line": 175, "start_col": 0, "end_col": 14, "parent_name": null, "docstring": "Scale bounding box from resized image coordinates to original image coordinates.\n\nYOLO resizes images during training and the label values are normalized based on this resized shape.\nThis function rescales the bounding box labels to the original image shape.\n\nArgs:\n box (torch.Tensor): Bounding box in normalized xywh format.\n resized_image_shape (tuple): Shape of the resized image (height, width).\n original_image_shape (tuple): Shape of the original image (height, width).\n ratio_pad (tuple): Ratio and padding information for scaling.\n\nReturns:\n (List[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.", "parameters": [ "box", "resized_image_shape", "original_image_shape", "ratio_pad" ], "return_type": "List[float]", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__scale_bounding_box_to_original_image_shape_c86c2d8f" }, { "content": "def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> Optional[dict]:\n \"\"\"\n Format ground truth annotations for object detection.\n\n This function processes ground truth annotations from a batch of images for object detection tasks. 
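The `_scale_bounding_box_to_original_image_shape` chunk above chains four `ops` calls. As a rough stand-in for what that chain computes (a simplified sketch, not the `ultralytics.utils.ops` implementation; it assumes a plain resize and ignores the `ratio_pad` letterbox correction):

```python
def rescale_box_sketch(box_xywhn, resized_hw, original_hw):
    """Toy equivalent of the xywhn -> xyxy -> scaled xyxy -> top-left xywh chain."""
    rh, rw = resized_hw
    oh, ow = original_hw
    x, y, w, h = box_xywhn
    # normalized xywh -> absolute xyxy in the resized image
    x1, y1, x2, y2 = (x - w / 2) * rw, (y - h / 2) * rh, (x + w / 2) * rw, (y + h / 2) * rh
    # rescale to the original image (plain resize assumed, no padding correction)
    gx, gy = ow / rw, oh / rh
    x1, y1, x2, y2 = x1 * gx, y1 * gy, x2 * gx, y2 * gy
    # xyxy -> xywh where (x, y) is the top-left corner, as the Comet payload expects
    return [x1, y1, x2 - x1, y2 - y1]

print(rescale_box_sketch((0.5, 0.5, 0.2, 0.4), resized_hw=(640, 640), original_hw=(480, 800)))
```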
It extracts\n bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for\n visualization or evaluation.\n\n Args:\n img_idx (int): Index of the image in the batch to process.\n image_path (str | Path): Path to the image file.\n batch (dict): Batch dictionary containing detection data with keys:\n - 'batch_idx': Tensor of batch indices\n - 'bboxes': Tensor of bounding boxes in normalized xywh format\n - 'cls': Tensor of class labels\n - 'ori_shape': Original image shapes\n - 'resized_shape': Resized image shapes\n - 'ratio_pad': Ratio and padding information\n class_name_map (dict, optional): Mapping from class indices to class names.\n\n Returns:\n (dict | None): Formatted ground truth annotations with the following structure:\n - 'boxes': List of box coordinates [x, y, width, height]\n - 'label': Label string with format \"gt_{class_name}\"\n - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)\n Returns None if no bounding boxes are found for the image.\n \"\"\"\n indices = batch[\"batch_idx\"] == img_idx\n bboxes = batch[\"bboxes\"][indices]\n if len(bboxes) == 0:\n LOGGER.debug(f\"Comet Image: {image_path} has no bounding boxes labels\")\n return None\n\n cls_labels = batch[\"cls\"][indices].squeeze(1).tolist()\n if class_name_map:\n cls_labels = [str(class_name_map[label]) for label in cls_labels]\n\n original_image_shape = batch[\"ori_shape\"][img_idx]\n resized_image_shape = batch[\"resized_shape\"][img_idx]\n ratio_pad = batch[\"ratio_pad\"][img_idx]\n\n data = []\n for box, label in zip(bboxes, cls_labels):\n box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)\n data.append(\n {\n \"boxes\": [box],\n \"label\": f\"gt_{label}\",\n \"score\": _scale_confidence_score(1.0),\n }\n )\n\n return {\"name\": \"ground_truth\", \"data\": data}", "chunk_type": "function", "name": "_format_ground_truth_annotations_for_detection", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 178, "end_line": 230, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": "Format ground truth annotations for object detection.\n\nThis function processes ground truth annotations from a batch of images for object detection tasks. 
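For reference, the payload returned by `_format_ground_truth_annotations_for_detection` has the shape sketched below with made-up values; with the default `COMET_MAX_CONFIDENCE_SCORE` of 100.0, the fixed ground-truth score of 1.0 is logged as 100.0:

```python
# Illustrative only: the structure assembled by the chunk above, with toy values.
ground_truth_annotation = {
    "name": "ground_truth",
    "data": [
        {
            "boxes": [[320.0, 144.0, 160.0, 192.0]],  # [x, y, width, height], top-left origin
            "label": "gt_person",                     # "gt_" prefix + mapped class name
            "score": 100.0,                           # _scale_confidence_score(1.0) with the default scale
        }
    ],
}
print(ground_truth_annotation["data"][0]["label"])
```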
It extracts\nbounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for\nvisualization or evaluation.\n\nArgs:\n img_idx (int): Index of the image in the batch to process.\n image_path (str | Path): Path to the image file.\n batch (dict): Batch dictionary containing detection data with keys:\n - 'batch_idx': Tensor of batch indices\n - 'bboxes': Tensor of bounding boxes in normalized xywh format\n - 'cls': Tensor of class labels\n - 'ori_shape': Original image shapes\n - 'resized_shape': Resized image shapes\n - 'ratio_pad': Ratio and padding information\n class_name_map (dict, optional): Mapping from class indices to class names.\n\nReturns:\n (dict | None): Formatted ground truth annotations with the following structure:\n - 'boxes': List of box coordinates [x, y, width, height]\n - 'label': Label string with format \"gt_{class_name}\"\n - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)\n Returns None if no bounding boxes are found for the image.", "parameters": [ "img_idx", "image_path", "batch", "class_name_map" ], "return_type": "Optional[dict]", "decorators": [], "complexity_score": 5, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__format_ground_truth_annotations_for_detection_ea29e256" }, { "content": "def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> Optional[dict]:\n \"\"\"\n Format YOLO predictions for object detection visualization.\n\n Args:\n image_path (Path): Path to the image file.\n metadata (dict): Prediction metadata containing bounding boxes and class information.\n class_label_map (dict, optional): Mapping from class indices to class names.\n class_map (dict, optional): Additional class mapping for label conversion.\n\n Returns:\n (dict | None): Formatted prediction annotations or None if no predictions exist.\n \"\"\"\n stem = image_path.stem\n image_id = int(stem) if stem.isnumeric() else stem\n\n predictions = metadata.get(image_id)\n if not predictions:\n LOGGER.debug(f\"Comet Image: {image_path} has no bounding boxes predictions\")\n return None\n\n # apply the mapping that was used to map the predicted classes when the JSON was created\n if class_label_map and class_map:\n class_label_map = {class_map[k]: v for k, v in class_label_map.items()}\n try:\n # import pycotools utilities to decompress annotations for various tasks, e.g. 
segmentation\n from faster_coco_eval.core.mask import decode # noqa\n except ImportError:\n decode = None\n\n data = []\n for prediction in predictions:\n boxes = prediction[\"bbox\"]\n score = _scale_confidence_score(prediction[\"score\"])\n cls_label = prediction[\"category_id\"]\n if class_label_map:\n cls_label = str(class_label_map[cls_label])\n\n annotation_data = {\"boxes\": [boxes], \"label\": cls_label, \"score\": score}\n\n if decode is not None:\n # do segmentation processing only if we are able to decode it\n segments = prediction.get(\"segmentation\", None)\n if segments is not None:\n segments = _extract_segmentation_annotation(segments, decode)\n if segments is not None:\n annotation_data[\"points\"] = segments\n\n data.append(annotation_data)\n\n return {\"name\": \"prediction\", \"data\": data}", "chunk_type": "function", "name": "_format_prediction_annotations", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 233, "end_line": 283, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Format YOLO predictions for object detection visualization.\n\nArgs:\n image_path (Path): Path to the image file.\n metadata (dict): Prediction metadata containing bounding boxes and class information.\n class_label_map (dict, optional): Mapping from class indices to class names.\n class_map (dict, optional): Additional class mapping for label conversion.\n\nReturns:\n (dict | None): Formatted prediction annotations or None if no predictions exist.", "parameters": [ "image_path", "metadata", "class_label_map", "class_map" ], "return_type": "Optional[dict]", "decorators": [], "complexity_score": 10, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__format_prediction_annotations_f01ea5f2" }, { "content": "def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> Optional[List[List[Any]]]:\n \"\"\"\n Extract segmentation annotation from compressed segmentations as list of polygons.\n\n Args:\n segmentation_raw (str): Raw segmentation data in compressed format.\n decode (Callable): Function to decode the compressed segmentation data.\n\n Returns:\n (List[List[Any]] | None): List of polygon points or None if extraction fails.\n \"\"\"\n try:\n mask = decode(segmentation_raw)\n contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)\n annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3]\n return [annotation.ravel().tolist() for annotation in annotations]\n except Exception as e:\n LOGGER.warning(f\"Comet Failed to extract segmentation annotation: {e}\")\n return None", "chunk_type": "function", "name": "_extract_segmentation_annotation", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 286, "end_line": 304, "start_col": 0, "end_col": 15, "parent_name": null, "docstring": "Extract segmentation annotation from compressed segmentations as list of 
polygons.\n\nArgs:\n segmentation_raw (str): Raw segmentation data in compressed format.\n decode (Callable): Function to decode the compressed segmentation data.\n\nReturns:\n (List[List[Any]] | None): List of polygon points or None if extraction fails.", "parameters": [ "segmentation_raw: str", "decode: Callable" ], "return_type": "Optional[List[List[Any]]]", "decorators": [], "complexity_score": 4, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__extract_segmentation_annotation_62de6e1a" }, { "content": "def _fetch_annotations(\n img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map\n) -> Optional[List]:\n \"\"\"\n Join the ground truth and prediction annotations if they exist.\n\n Args:\n img_idx (int): Index of the image in the batch.\n image_path (Path): Path to the image file.\n batch (dict): Batch data containing ground truth annotations.\n prediction_metadata_map (dict): Map of prediction metadata by image ID.\n class_label_map (dict): Mapping from class indices to class names.\n class_map (dict): Additional class mapping for label conversion.\n\n Returns:\n (List | None): List of annotation dictionaries or None if no annotations exist.\n \"\"\"\n ground_truth_annotations = _format_ground_truth_annotations_for_detection(\n img_idx, image_path, batch, class_label_map\n )\n prediction_annotations = _format_prediction_annotations(\n image_path, prediction_metadata_map, class_label_map, class_map\n )\n\n annotations = [\n annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None\n ]\n return [annotations] if annotations else None", "chunk_type": "function", "name": "_fetch_annotations", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 307, "end_line": 334, "start_col": 0, "end_col": 49, "parent_name": null, "docstring": "Join the ground truth and prediction annotations if they exist.\n\nArgs:\n img_idx (int): Index of the image in the batch.\n image_path (Path): Path to the image file.\n batch (dict): Batch data containing ground truth annotations.\n prediction_metadata_map (dict): Map of prediction metadata by image ID.\n class_label_map (dict): Mapping from class indices to class names.\n class_map (dict): Additional class mapping for label conversion.\n\nReturns:\n (List | None): List of annotation dictionaries or None if no annotations exist.", "parameters": [ "img_idx", "image_path", "batch", "prediction_metadata_map", "class_label_map", "class_map" ], "return_type": "Optional[List]", "decorators": [], "complexity_score": 2, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", 
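The mask-to-polygon step inside `_extract_segmentation_annotation` can be exercised on a synthetic binary mask; the sketch below skips the RLE decoding that the real code performs via `faster_coco_eval` and starts from an already-decoded mask:

```python
import cv2
import numpy as np

# A filled square standing in for a decoded instance mask.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 1

contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
polygons = [np.array(c).squeeze().ravel().tolist() for c in contours if len(c) >= 3]
print(polygons)  # flat [x1, y1, x2, y2, ...] lists, used as the "points" annotation field
```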
"ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__fetch_annotations_979a6a1f" }, { "content": "def _create_prediction_metadata_map(model_predictions) -> dict:\n \"\"\"Create metadata map for model predictions by grouping them based on image ID.\"\"\"\n pred_metadata_map = {}\n for prediction in model_predictions:\n pred_metadata_map.setdefault(prediction[\"image_id\"], [])\n pred_metadata_map[prediction[\"image_id\"]].append(prediction)\n\n return pred_metadata_map", "chunk_type": "function", "name": "_create_prediction_metadata_map", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 337, "end_line": 344, "start_col": 0, "end_col": 28, "parent_name": null, "docstring": "Create metadata map for model predictions by grouping them based on image ID.", "parameters": [ "model_predictions" ], "return_type": "dict", "decorators": [], "complexity_score": 2, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__create_prediction_metadata_map_287d87d4" }, { "content": "def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:\n \"\"\"Log the confusion matrix to Comet experiment.\"\"\"\n conf_mat = trainer.validator.confusion_matrix.matrix\n names = list(trainer.data[\"names\"].values()) + [\"background\"]\n experiment.log_confusion_matrix(\n matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step\n )", "chunk_type": "function", "name": "_log_confusion_matrix", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 347, "end_line": 353, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Log the confusion matrix to Comet experiment.", "parameters": [ "experiment", "trainer", "curr_step", "curr_epoch" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_confusion_matrix_af670c19" }, { "content": "def _log_images(experiment, image_paths, curr_step, annotations=None) -> None:\n \"\"\"\n Log images to the experiment with optional annotations.\n\n This function logs images to a Comet ML experiment, optionally including annotation 
data for visualization\n such as bounding boxes or segmentation masks.\n\n Args:\n experiment (comet_ml.Experiment): The Comet ML experiment to log images to.\n image_paths (List[Path]): List of paths to images that will be logged.\n curr_step (int): Current training step/iteration for tracking in the experiment timeline.\n annotations (List[List[dict]], optional): Nested list of annotation dictionaries for each image. Each\n annotation contains visualization data like bounding boxes, labels, and confidence scores.\n \"\"\"\n if annotations:\n for image_path, annotation in zip(image_paths, annotations):\n experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)\n\n else:\n for image_path in image_paths:\n experiment.log_image(image_path, name=image_path.stem, step=curr_step)", "chunk_type": "function", "name": "_log_images", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 356, "end_line": 376, "start_col": 0, "end_col": 82, "parent_name": null, "docstring": "Log images to the experiment with optional annotations.\n\nThis function logs images to a Comet ML experiment, optionally including annotation data for visualization\nsuch as bounding boxes or segmentation masks.\n\nArgs:\n experiment (comet_ml.Experiment): The Comet ML experiment to log images to.\n image_paths (List[Path]): List of paths to images that will be logged.\n curr_step (int): Current training step/iteration for tracking in the experiment timeline.\n annotations (List[List[dict]], optional): Nested list of annotation dictionaries for each image. Each\n annotation contains visualization data like bounding boxes, labels, and confidence scores.", "parameters": [ "experiment", "image_paths", "curr_step", "annotations" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_images_630cee76" }, { "content": "def _log_image_predictions(experiment, validator, curr_step) -> None:\n \"\"\"\n Log predicted boxes for a single image during training.\n\n This function logs image predictions to a Comet ML experiment during model validation. It processes\n validation data and formats both ground truth and prediction annotations for visualization in the Comet\n dashboard. 
The function respects configured limits on the number of images to log.\n\n Args:\n experiment (comet_ml.Experiment): The Comet ML experiment to log to.\n validator (BaseValidator): The validator instance containing validation data and predictions.\n curr_step (int): The current training step for logging timeline.\n\n Notes:\n This function uses global state to track the number of logged predictions across calls.\n It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.\n The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.\n \"\"\"\n global _comet_image_prediction_count\n\n task = validator.args.task\n if task not in COMET_SUPPORTED_TASKS:\n return\n\n jdict = validator.jdict\n if not jdict:\n return\n\n predictions_metadata_map = _create_prediction_metadata_map(jdict)\n dataloader = validator.dataloader\n class_label_map = validator.names\n class_map = getattr(validator, \"class_map\", None)\n\n batch_logging_interval = _get_eval_batch_logging_interval()\n max_image_predictions = _get_max_image_predictions_to_log()\n\n for batch_idx, batch in enumerate(dataloader):\n if (batch_idx + 1) % batch_logging_interval != 0:\n continue\n\n image_paths = batch[\"im_file\"]\n for img_idx, image_path in enumerate(image_paths):\n if _comet_image_prediction_count >= max_image_predictions:\n return\n\n image_path = Path(image_path)\n annotations = _fetch_annotations(\n img_idx,\n image_path,\n batch,\n predictions_metadata_map,\n class_label_map,\n class_map=class_map,\n )\n _log_images(\n experiment,\n [image_path],\n curr_step,\n annotations=annotations,\n )\n _comet_image_prediction_count += 1", "chunk_type": "function", "name": "_log_image_predictions", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 379, "end_line": 439, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Log predicted boxes for a single image during training.\n\nThis function logs image predictions to a Comet ML experiment during model validation. It processes\nvalidation data and formats both ground truth and prediction annotations for visualization in the Comet\ndashboard. 
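The sampling behaviour described above (log only every Nth validation batch, stop once the global cap is hit) is easy to see in isolation; a small sketch with hypothetical counts:

```python
def select_images_to_log(num_batches, images_per_batch, interval=1, max_images=100):
    """Return (batch_idx, img_idx) pairs that would be logged under the gating above."""
    selected = []
    for batch_idx in range(num_batches):
        if (batch_idx + 1) % interval != 0:  # same modulo check as the chunk
            continue
        for img_idx in range(images_per_batch):
            if len(selected) >= max_images:
                return selected
            selected.append((batch_idx, img_idx))
    return selected

print(len(select_images_to_log(20, 16, interval=4, max_images=40)))  # 40
```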
The function respects configured limits on the number of images to log.\n\nArgs:\n experiment (comet_ml.Experiment): The Comet ML experiment to log to.\n validator (BaseValidator): The validator instance containing validation data and predictions.\n curr_step (int): The current training step for logging timeline.\n\nNotes:\n This function uses global state to track the number of logged predictions across calls.\n It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.\n The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.", "parameters": [ "experiment", "validator", "curr_step" ], "return_type": "None", "decorators": [], "complexity_score": 7, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_image_predictions_0eafc348" }, { "content": "def _log_plots(experiment, trainer) -> None:\n \"\"\"\n Log evaluation plots and label plots for the experiment.\n\n This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles\n different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots\n for each type.\n\n Args:\n experiment (comet_ml.Experiment): The Comet ML experiment to log plots to.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save\n directory information.\n\n Examples:\n >>> from ultralytics.utils.callbacks.comet import _log_plots\n >>> _log_plots(experiment, trainer)\n \"\"\"\n plot_filenames = None\n if isinstance(trainer.validator.metrics, SegmentMetrics):\n plot_filenames = [\n trainer.save_dir / f\"{prefix}{plots}.png\"\n for plots in EVALUATION_PLOT_NAMES\n for prefix in SEGMENT_METRICS_PLOT_PREFIX\n ]\n elif isinstance(trainer.validator.metrics, PoseMetrics):\n plot_filenames = [\n trainer.save_dir / f\"{prefix}{plots}.png\"\n for plots in EVALUATION_PLOT_NAMES\n for prefix in POSE_METRICS_PLOT_PREFIX\n ]\n elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):\n plot_filenames = [trainer.save_dir / f\"{plots}.png\" for plots in EVALUATION_PLOT_NAMES]\n\n if plot_filenames is not None:\n _log_images(experiment, plot_filenames, None)\n\n confusion_matrix_filenames = [trainer.save_dir / f\"{plots}.png\" for plots in CONFUSION_MATRIX_PLOT_NAMES]\n _log_images(experiment, confusion_matrix_filenames, None)\n\n if not isinstance(trainer.validator.metrics, ClassifyMetrics):\n label_plot_filenames = [trainer.save_dir / f\"{labels}.jpg\" for labels in LABEL_PLOT_NAMES]\n _log_images(experiment, label_plot_filenames, None)", "chunk_type": "function", "name": "_log_plots", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 442, "end_line": 483, "start_col": 0, "end_col": 59, "parent_name": null, "docstring": "Log evaluation plots and label plots for the experiment.\n\nThis function logs various evaluation plots and confusion matrices to the 
experiment tracking system. It handles\ndifferent types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots\nfor each type.\n\nArgs:\n experiment (comet_ml.Experiment): The Comet ML experiment to log plots to.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save\n directory information.\n\nExamples:\n >>> from ultralytics.utils.callbacks.comet import _log_plots\n >>> _log_plots(experiment, trainer)", "parameters": [ "experiment", "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 11, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_plots_05fa3549" }, { "content": "def _log_model(experiment, trainer) -> None:\n \"\"\"Log the best-trained model to Comet.ml.\"\"\"\n model_name = _get_comet_model_name()\n experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name=\"best.pt\", overwrite=True)", "chunk_type": "function", "name": "_log_model", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 486, "end_line": 489, "start_col": 0, "end_col": 107, "parent_name": null, "docstring": "Log the best-trained model to Comet.ml.", "parameters": [ "experiment", "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_model_f0ac61d6" }, { "content": "def _log_image_batches(experiment, trainer, curr_step: int) -> None:\n \"\"\"Log samples of image batches for train, validation, and test.\"\"\"\n _log_images(experiment, trainer.save_dir.glob(\"train_batch*.jpg\"), curr_step)\n _log_images(experiment, trainer.save_dir.glob(\"val_batch*.jpg\"), curr_step)", "chunk_type": "function", "name": "_log_image_batches", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 492, "end_line": 495, "start_col": 0, "end_col": 79, "parent_name": null, "docstring": "Log samples of image batches for train, validation, and test.", "parameters": [ "experiment", "trainer", "curr_step: int" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", 
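`_log_image_batches` above relies only on the `train_batch*.jpg` / `val_batch*.jpg` naming of the batch preview images written into the run directory; a hedged usage sketch (the directory path is hypothetical):

```python
from pathlib import Path

save_dir = Path("runs/detect/train")  # hypothetical run directory
train_previews = sorted(save_dir.glob("train_batch*.jpg"))
val_previews = sorted(save_dir.glob("val_batch*.jpg"))
print(f"{len(train_previews)} train previews, {len(val_previews)} val previews")
```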
"ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_image_batches_061ac9f1" }, { "content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Create or resume a CometML experiment at the start of a YOLO pre-training routine.\"\"\"\n _resume_or_create_experiment(trainer.args)", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 498, "end_line": 500, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Create or resume a CometML experiment at the start of a YOLO pre-training routine.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_start_3223cb38" }, { "content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log metrics and save batch images at the end of training epochs.\"\"\"\n experiment = comet_ml.get_running_experiment()\n if not experiment:\n return\n\n metadata = _fetch_trainer_metadata(trainer)\n curr_epoch = metadata[\"curr_epoch\"]\n curr_step = metadata[\"curr_step\"]\n\n experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), step=curr_step, epoch=curr_epoch)", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 503, "end_line": 513, "start_col": 0, "end_col": 117, "parent_name": null, "docstring": "Log metrics and save batch images at the end of training epochs.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_epoch_end_b5a2fb54" }, { "content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"\n Log model assets at the end of each epoch during training.\n\n This function is called at the end of each training epoch to log metrics, learning rates, and model information\n to a Comet ML experiment. 
It also logs model assets, confusion matrices, and image predictions based on\n configuration settings.\n\n The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,\n it also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled),\n and image predictions (if enabled).\n\n Args:\n trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.\n\n Examples:\n >>> # Inside a training loop\n >>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML\n \"\"\"\n experiment = comet_ml.get_running_experiment()\n if not experiment:\n return\n\n metadata = _fetch_trainer_metadata(trainer)\n curr_epoch = metadata[\"curr_epoch\"]\n curr_step = metadata[\"curr_step\"]\n save_assets = metadata[\"save_assets\"]\n\n experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)\n experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)\n if curr_epoch == 1:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)\n\n if not save_assets:\n return\n\n _log_model(experiment, trainer)\n if _should_log_confusion_matrix():\n _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)\n if _should_log_image_predictions():\n _log_image_predictions(experiment, trainer.validator, curr_step)", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 516, "end_line": 558, "start_col": 0, "end_col": 72, "parent_name": null, "docstring": "Log model assets at the end of each epoch during training.\n\nThis function is called at the end of each training epoch to log metrics, learning rates, and model information\nto a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on\nconfiguration settings.\n\nThe function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,\nit also logs model information. 
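Whether `on_fit_epoch_end` uploads assets at all is decided by the `save_assets` flag computed in `_fetch_trainer_metadata`; the schedule it produces can be reproduced with toy numbers:

```python
def should_save_assets(curr_epoch, total_epochs, save=True, save_period=5):
    """Mirror of the save_assets expression in _fetch_trainer_metadata (toy inputs)."""
    final_epoch = curr_epoch == total_epochs
    return save and save_period > 0 and curr_epoch % save_period == 0 and not final_epoch

print([e for e in range(1, 21) if should_save_assets(e, total_epochs=20)])  # [5, 10, 15]
```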
On specified save intervals, it logs the model, confusion matrix (if enabled),\nand image predictions (if enabled).\n\nArgs:\n trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.\n\nExamples:\n >>> # Inside a training loop\n >>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 6, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_fit_epoch_end_9ea290e4" }, { "content": "def on_train_end(trainer) -> None:\n \"\"\"Perform operations at the end of training.\"\"\"\n experiment = comet_ml.get_running_experiment()\n if not experiment:\n return\n\n metadata = _fetch_trainer_metadata(trainer)\n curr_epoch = metadata[\"curr_epoch\"]\n curr_step = metadata[\"curr_step\"]\n plots = trainer.args.plots\n\n _log_model(experiment, trainer)\n if plots:\n _log_plots(experiment, trainer)\n\n _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)\n _log_image_predictions(experiment, trainer.validator, curr_step)\n _log_image_batches(experiment, trainer, curr_step)\n experiment.end()\n\n global _comet_image_prediction_count\n _comet_image_prediction_count = 0", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 561, "end_line": 582, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Perform operations at the end of training.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "collections.abc.Callable", "types.SimpleNamespace", "typing.Any", "typing.List", "typing.Optional", "cv2", "numpy", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.ops", "ultralytics.utils.metrics.ClassifyMetrics", "ultralytics.utils.metrics.DetMetrics", "ultralytics.utils.metrics.OBBMetrics", "ultralytics.utils.metrics.PoseMetrics", "ultralytics.utils.metrics.SegmentMetrics", "comet_ml", "os", "pathlib.Path", "faster_coco_eval.core.mask.decode", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_end_49d0684a" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if comet_ml\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py", "start_line": 585, "end_line": 594, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_65d87ac9" }, { "content": "from pathlib import Path", "chunk_type": "import", "name": "Path", 
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 24, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_Path_35b05885" }, { "content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks", "chunk_type": "import", "name": "LOGGER, SETTINGS, TESTS_RUNNING, checks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 69, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING, checks_e208c880" }, { "content": "def _log_images(path: Path, prefix: str = \"\") -> None:\n \"\"\"\n Log images at specified path with an optional prefix using DVCLive.\n\n This function logs images found at the given path to DVCLive, organizing them by batch to enable slider\n functionality in the UI. It processes image filenames to extract batch information and restructures the path\n accordingly.\n\n Args:\n path (Path): Path to the image file to be logged.\n prefix (str, optional): Optional prefix to add to the image name when logging.\n\n Examples:\n >>> from pathlib import Path\n >>> _log_images(Path(\"runs/train/exp/val_batch0_pred.jpg\"), prefix=\"validation\")\n \"\"\"\n if live:\n name = path.name\n\n # Group images by batch to enable sliders in UI\n if m := re.search(r\"_batch(\\d+)\", name):\n ni = m[1]\n new_stem = re.sub(r\"_batch(\\d+)\", \"_batch\", path.stem)\n name = (Path(new_stem) / ni).with_suffix(path.suffix)\n\n live.log_image(os.path.join(prefix, name), path)", "chunk_type": "function", "name": "_log_images", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 29, "end_line": 54, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "Log images at specified path with an optional prefix using DVCLive.\n\nThis function logs images found at the given path to DVCLive, organizing them by batch to enable slider\nfunctionality in the UI. 
It processes image filenames to extract batch information and restructures the path\naccordingly.\n\nArgs:\n path (Path): Path to the image file to be logged.\n prefix (str, optional): Optional prefix to add to the image name when logging.\n\nExamples:\n >>> from pathlib import Path\n >>> _log_images(Path(\"runs/train/exp/val_batch0_pred.jpg\"), prefix=\"validation\")", "parameters": [ "path: Path", "prefix: str" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_images_76e8deb2" }, { "content": "def _log_plots(plots: dict, prefix: str = \"\") -> None:\n \"\"\"\n Log plot images for training progress if they have not been previously processed.\n\n Args:\n plots (dict): Dictionary containing plot information with timestamps.\n prefix (str, optional): Optional prefix to add to the logged image paths.\n \"\"\"\n for name, params in plots.items():\n timestamp = params[\"timestamp\"]\n if _processed_plots.get(name) != timestamp:\n _log_images(name, prefix)\n _processed_plots[name] = timestamp", "chunk_type": "function", "name": "_log_plots", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 57, "end_line": 69, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Log plot images for training progress if they have not been previously processed.\n\nArgs:\n plots (dict): Dictionary containing plot information with timestamps.\n prefix (str, optional): Optional prefix to add to the logged image paths.", "parameters": [ "plots: dict", "prefix: str" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_plots_55416a93" }, { "content": "def _log_confusion_matrix(validator) -> None:\n \"\"\"\n Log confusion matrix for a validator using DVCLive.\n\n This function processes the confusion matrix from a validator object and logs it to DVCLive by converting\n the matrix into lists of target and prediction labels.\n\n Args:\n validator (BaseValidator): The validator object containing the confusion matrix and class names. 
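`_log_plots` above avoids re-uploading unchanged plots by remembering the timestamp it last saw for each name; a minimal sketch of that de-duplication:

```python
_processed_plots = {}

def log_new_plots(plots: dict, log_fn) -> None:
    """Only forward plots whose timestamp changed since the last call (same idea as above)."""
    for name, params in plots.items():
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            log_fn(name)
            _processed_plots[name] = timestamp

log_new_plots({"results.png": {"timestamp": 1}}, print)  # logged
log_new_plots({"results.png": {"timestamp": 1}}, print)  # skipped, unchanged
log_new_plots({"results.png": {"timestamp": 2}}, print)  # logged again
```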
Must have\n attributes: confusion_matrix.matrix, confusion_matrix.task, and names.\n \"\"\"\n targets = []\n preds = []\n matrix = validator.confusion_matrix.matrix\n names = list(validator.names.values())\n if validator.confusion_matrix.task == \"detect\":\n names += [\"background\"]\n\n for ti, pred in enumerate(matrix.T.astype(int)):\n for pi, num in enumerate(pred):\n targets.extend([names[ti]] * num)\n preds.extend([names[pi]] * num)\n\n live.log_sklearn_plot(\"confusion_matrix\", targets, preds, name=\"cf.json\", normalized=True)", "chunk_type": "function", "name": "_log_confusion_matrix", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 72, "end_line": 95, "start_col": 0, "end_col": 94, "parent_name": null, "docstring": "Log confusion matrix for a validator using DVCLive.\n\nThis function processes the confusion matrix from a validator object and logs it to DVCLive by converting\nthe matrix into lists of target and prediction labels.\n\nArgs:\n validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have\n attributes: confusion_matrix.matrix, confusion_matrix.task, and names.", "parameters": [ "validator" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_confusion_matrix_9a6afa08" }, { "content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize DVCLive logger for training metadata during pre-training routine.\"\"\"\n try:\n global live\n live = dvclive.Live(save_dvc_exp=True, cache_images=True)\n LOGGER.info(\"DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).\")\n except Exception as e:\n LOGGER.warning(f\"DVCLive installed but not initialized correctly, not logging this run. 
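The conversion in the DVCLive `_log_confusion_matrix` chunk expands each matrix cell count into that many (target, prediction) label pairs before calling `live.log_sklearn_plot`; a toy two-class-plus-background example:

```python
import numpy as np

names = ["cat", "dog", "background"]
matrix = np.array([[3, 1, 0], [0, 2, 1], [1, 0, 0]])  # toy counts

targets, preds = [], []
for ti, pred in enumerate(matrix.T.astype(int)):  # same iteration order as the chunk
    for pi, num in enumerate(pred):
        targets.extend([names[ti]] * num)
        preds.extend([names[pi]] * num)

print(list(zip(targets, preds)))  # one (target, prediction) pair per counted sample
```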
{e}\")", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 98, "end_line": 105, "start_col": 0, "end_col": 101, "parent_name": null, "docstring": "Initialize DVCLive logger for training metadata during pre-training routine.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_start_78643cf6" }, { "content": "def on_pretrain_routine_end(trainer) -> None:\n \"\"\"Log plots related to the training process at the end of the pretraining routine.\"\"\"\n _log_plots(trainer.plots, \"train\")", "chunk_type": "function", "name": "on_pretrain_routine_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 108, "end_line": 110, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Log plots related to the training process at the end of the pretraining routine.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_end_3e9575ea" }, { "content": "def on_train_start(trainer) -> None:\n \"\"\"Log the training parameters if DVCLive logging is active.\"\"\"\n if live:\n live.log_params(trainer.args)", "chunk_type": "function", "name": "on_train_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 113, "end_line": 116, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Log the training parameters if DVCLive logging is active.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_start_af29dda1" }, { "content": "def on_train_epoch_start(trainer) -> None:\n \"\"\"Set the global variable _training_epoch value to True at the start of training each epoch.\"\"\"\n global _training_epoch\n _training_epoch = True", "chunk_type": "function", "name": "on_train_epoch_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 119, "end_line": 122, "start_col": 0, "end_col": 26, "parent_name": null, "docstring": "Set the global variable _training_epoch value to True at the start of training each epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_epoch_start_02ab2515" }, { "content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"\n Log training metrics, model info, and advance to next step at the end of each fit epoch.\n\n This function is called at the end of each 
fit epoch during training. It logs various metrics including\n training loss items, validation metrics, and learning rates. On the first epoch, it also logs model\n information. Additionally, it logs training and validation plots and advances the DVCLive step counter.\n\n Args:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.\n\n Notes:\n This function only performs logging operations when DVCLive logging is active and during a training epoch.\n The global variable _training_epoch is used to track whether the current epoch is a training epoch.\n \"\"\"\n global _training_epoch\n if live and _training_epoch:\n all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix=\"train\"), **trainer.metrics, **trainer.lr}\n for metric, value in all_metrics.items():\n live.log_metric(metric, value)\n\n if trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n for metric, value in model_info_for_loggers(trainer).items():\n live.log_metric(metric, value, plot=False)\n\n _log_plots(trainer.plots, \"train\")\n _log_plots(trainer.validator.plots, \"val\")\n\n live.next_step()\n _training_epoch = False", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 125, "end_line": 156, "start_col": 0, "end_col": 31, "parent_name": null, "docstring": "Log training metrics, model info, and advance to next step at the end of each fit epoch.\n\nThis function is called at the end of each fit epoch during training. It logs various metrics including\ntraining loss items, validation metrics, and learning rates. On the first epoch, it also logs model\ninformation. Additionally, it logs training and validation plots and advances the DVCLive step counter.\n\nArgs:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.\n\nNotes:\n This function only performs logging operations when DVCLive logging is active and during a training epoch.\n The global variable _training_epoch is used to track whether the current epoch is a training epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 5, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_fit_epoch_end_88ab55a6" }, { "content": "def on_train_end(trainer) -> None:\n \"\"\"\n Log best metrics, plots, and confusion matrix at the end of training.\n\n This function is called at the conclusion of the training process to log final metrics, visualizations, and\n model artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,\n validation plots, and confusion matrix for later analysis.\n\n Args:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.\n\n Examples:\n >>> # Inside a custom training loop\n >>> from ultralytics.utils.callbacks.dvc import on_train_end\n >>> on_train_end(trainer) # Log final metrics and artifacts\n \"\"\"\n if live:\n # At the end log the best metrics. 
It runs validator on the best model internally.\n all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix=\"train\"), **trainer.metrics, **trainer.lr}\n for metric, value in all_metrics.items():\n live.log_metric(metric, value, plot=False)\n\n _log_plots(trainer.plots, \"val\")\n _log_plots(trainer.validator.plots, \"val\")\n _log_confusion_matrix(trainer.validator)\n\n if trainer.best.exists():\n live.log_artifact(trainer.best, copy=True, type=\"model\")\n\n live.end()", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 159, "end_line": 188, "start_col": 0, "end_col": 18, "parent_name": null, "docstring": "Log best metrics, plots, and confusion matrix at the end of training.\n\nThis function is called at the conclusion of the training process to log final metrics, visualizations, and\nmodel artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,\nvalidation plots, and confusion matrix for later analysis.\n\nArgs:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.\n\nExamples:\n >>> # Inside a custom training loop\n >>> from ultralytics.utils.callbacks.dvc import on_train_end\n >>> on_train_end(trainer) # Log final metrics and artifacts", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 4, "dependencies": [ "pathlib.Path", "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.checks", "dvclive", "os", "re", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_end_10251516" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_pretrain_routine_end\": on_pretrain_routine_end,\n \"on_train_start\": on_train_start,\n \"on_train_epoch_start\": on_train_epoch_start,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if dvclive\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py", "start_line": 191, "end_line": 202, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_063bbfb2" }, { "content": "import json", "chunk_type": "import", "name": "json", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 11, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_json_cbf8f6e7" }, { "content": "from time import time", "chunk_type": "import", "name": "time", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 21, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_time_661e5670" }, { "content": "from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events", "chunk_type": "import", "name": "HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 6, "end_line": 6, "start_col": 0, "end_col": 76, "parent_name": null, 
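The DVCLive callbacks above follow a simple per-epoch protocol: gate every call on the module-level `live` object, log each metric under the current step, advance the step counter once per fit epoch, and call `end()` when training finishes. A minimal standalone sketch of that protocol, assuming `dvclive` is installed (the output directory, parameters, and metric names below are illustrative, not taken from the source):

# Minimal sketch of the DVCLive step-per-epoch pattern used by the dvc.py callbacks (illustrative values).
from dvclive import Live

live = Live("dvclive_demo")                     # hypothetical output directory
live.log_params({"epochs": 3, "imgsz": 640})    # like on_train_start -> live.log_params(trainer.args)
for epoch in range(3):
    metrics = {"train/box_loss": 1.0 / (epoch + 1), "lr/pg0": 0.01}  # stand-in metric values
    for name, value in metrics.items():
        live.log_metric(name, value)            # recorded under the current step
    live.next_step()                            # advance the step, as on_fit_epoch_end() does
live.end()                                      # finalize the run, as on_train_end() does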
"docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events_fee7c9ca" }, { "content": "from ultralytics.utils import LOGGER, RANK, SETTINGS", "chunk_type": "import", "name": "LOGGER, RANK, SETTINGS", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 7, "end_line": 7, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, RANK, SETTINGS_f6dce9b9" }, { "content": "def on_pretrain_routine_start(trainer):\n \"\"\"Create a remote Ultralytics HUB session to log local model training.\"\"\"\n if RANK in {-1, 0} and SETTINGS[\"hub\"] is True and SETTINGS[\"api_key\"] and trainer.hub_session is None:\n trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 10, "end_line": 13, "start_col": 0, "end_col": 97, "parent_name": null, "docstring": "Create a remote Ultralytics HUB session to log local model training.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_start_d36e6ce3" }, { "content": "def on_pretrain_routine_end(trainer):\n \"\"\"Initialize timers for upload rate limiting before training begins.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Start timer for upload rate limit\n session.timers = {\"metrics\": time(), \"ckpt\": time()} # start timer for session rate limiting", "chunk_type": "function", "name": "on_pretrain_routine_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 16, "end_line": 20, "start_col": 0, "end_col": 60, "parent_name": null, "docstring": "Initialize timers for upload rate limiting before training begins.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_end_8df60b02" }, { "content": "def on_fit_epoch_end(trainer):\n \"\"\"Upload training progress metrics to Ultralytics HUB at the end of each epoch.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Upload metrics after validation ends\n all_plots = {\n **trainer.label_loss_items(trainer.tloss, prefix=\"train\"),\n **trainer.metrics,\n }\n if trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n all_plots = {**all_plots, **model_info_for_loggers(trainer)}\n\n session.metrics_queue[trainer.epoch] = json.dumps(all_plots)\n\n # If any metrics failed to upload previously, add them to the queue to attempt uploading again\n if session.metrics_upload_failed_queue:\n 
session.metrics_queue.update(session.metrics_upload_failed_queue)\n\n if time() - session.timers[\"metrics\"] > session.rate_limits[\"metrics\"]:\n session.upload_metrics()\n session.timers[\"metrics\"] = time() # reset timer\n session.metrics_queue = {} # reset queue", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 23, "end_line": 45, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": "Upload training progress metrics to Ultralytics HUB at the end of each epoch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_fit_epoch_end_03530e00" }, { "content": "def on_model_save(trainer):\n \"\"\"Upload model checkpoints to Ultralytics HUB with rate limiting.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Upload checkpoints with rate limiting\n is_best = trainer.best_fitness == trainer.fitness\n if time() - session.timers[\"ckpt\"] > session.rate_limits[\"ckpt\"]:\n LOGGER.info(f\"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}\")\n session.upload_model(trainer.epoch, trainer.last, is_best)\n session.timers[\"ckpt\"] = time() # reset timer", "chunk_type": "function", "name": "on_model_save", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 48, "end_line": 56, "start_col": 0, "end_col": 43, "parent_name": null, "docstring": "Upload model checkpoints to Ultralytics HUB with rate limiting.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_model_save_76486db7" }, { "content": "def on_train_end(trainer):\n \"\"\"Upload final model and metrics to Ultralytics HUB at the end of training.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Upload final model and metrics with exponential standoff\n LOGGER.info(f\"{PREFIX}Syncing final model...\")\n session.upload_model(\n trainer.epoch,\n trainer.best,\n map=trainer.metrics.get(\"metrics/mAP50-95(B)\", 0),\n final=True,\n )\n session.alive = False # stop heartbeats\n LOGGER.info(f\"{PREFIX}Done ✅\\n{PREFIX}View model at {session.model_url} 🚀\")", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 59, "end_line": 71, "start_col": 0, "end_col": 88, "parent_name": null, "docstring": "Upload final model and metrics to Ultralytics HUB at the end of training.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": 
"function_on_train_end_82e21297" }, { "content": "def on_train_start(trainer):\n \"\"\"Run events on train start.\"\"\"\n events(trainer.args, trainer.device)", "chunk_type": "function", "name": "on_train_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 74, "end_line": 76, "start_col": 0, "end_col": 40, "parent_name": null, "docstring": "Run events on train start.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_start_91192112" }, { "content": "def on_val_start(validator):\n \"\"\"Run events on validation start.\"\"\"\n if not validator.training:\n events(validator.args, validator.device)", "chunk_type": "function", "name": "on_val_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 79, "end_line": 82, "start_col": 0, "end_col": 48, "parent_name": null, "docstring": "Run events on validation start.", "parameters": [ "validator" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_val_start_16876252" }, { "content": "def on_predict_start(predictor):\n \"\"\"Run events on predict start.\"\"\"\n events(predictor.args, predictor.device)", "chunk_type": "function", "name": "on_predict_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 85, "end_line": 87, "start_col": 0, "end_col": 44, "parent_name": null, "docstring": "Run events on predict start.", "parameters": [ "predictor" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_predict_start_6613bf11" }, { "content": "def on_export_start(exporter):\n \"\"\"Run events on export start.\"\"\"\n events(exporter.args, exporter.device)", "chunk_type": "function", "name": "on_export_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 90, "end_line": 92, "start_col": 0, "end_col": 42, "parent_name": null, "docstring": "Run events on export start.", "parameters": [ "exporter" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "json", "time.time", "ultralytics.hub.HUB_WEB_ROOT", "ultralytics.hub.PREFIX", "ultralytics.hub.HUBTrainingSession", "ultralytics.hub.events", "ultralytics.utils.LOGGER", "ultralytics.utils.RANK", "ultralytics.utils.SETTINGS", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_export_start_c3b5094d" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_pretrain_routine_end\": on_pretrain_routine_end,\n \"on_fit_epoch_end\": 
on_fit_epoch_end,\n \"on_model_save\": on_model_save,\n \"on_train_end\": on_train_end,\n \"on_train_start\": on_train_start,\n \"on_val_start\": on_val_start,\n \"on_predict_start\": on_predict_start,\n \"on_export_start\": on_export_start,\n }\n if SETTINGS[\"hub\"] is True\n else {}\n) # verify hub is enabled before registering callbacks", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py", "start_line": 95, "end_line": 109, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_88371912" }, { "content": "from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr", "chunk_type": "import", "name": "LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 24, "end_line": 24, "start_col": 0, "end_col": 81, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr_d1eb9590" }, { "content": "def sanitize_dict(x: dict) -> dict:\n \"\"\"Sanitize dictionary keys by removing parentheses and converting values to floats.\"\"\"\n return {k.replace(\"(\", \"\").replace(\")\", \"\"): float(v) for k, v in x.items()}", "chunk_type": "function", "name": "sanitize_dict", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 42, "end_line": 44, "start_col": 0, "end_col": 80, "parent_name": null, "docstring": "Sanitize dictionary keys by removing parentheses and converting values to floats.", "parameters": [ "x: dict" ], "return_type": "dict", "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "os", "mlflow", "pathlib.Path" ], "chunk_id": "function_sanitize_dict_3cc1ce93" }, { "content": "def on_pretrain_routine_end(trainer):\n \"\"\"\n Log training parameters to MLflow at the end of the pretraining routine.\n\n This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,\n experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters\n from the trainer.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.\n\n Environment Variables:\n MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.\n MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.\n MLFLOW_RUN: The name of the MLflow run. 
If not set, defaults to trainer.args.name.\n MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.\n \"\"\"\n global mlflow\n\n uri = os.environ.get(\"MLFLOW_TRACKING_URI\") or str(RUNS_DIR / \"mlflow\")\n LOGGER.debug(f\"{PREFIX} tracking uri: {uri}\")\n mlflow.set_tracking_uri(uri)\n\n # Set experiment and run names\n experiment_name = os.environ.get(\"MLFLOW_EXPERIMENT_NAME\") or trainer.args.project or \"/Shared/Ultralytics\"\n run_name = os.environ.get(\"MLFLOW_RUN\") or trainer.args.name\n mlflow.set_experiment(experiment_name)\n\n mlflow.autolog()\n try:\n active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)\n LOGGER.info(f\"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}\")\n if Path(uri).is_dir():\n LOGGER.info(f\"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'\")\n LOGGER.info(f\"{PREFIX}disable with 'yolo settings mlflow=False'\")\n mlflow.log_params(dict(trainer.args))\n except Exception as e:\n LOGGER.warning(f\"{PREFIX}Failed to initialize: {e}\")\n LOGGER.warning(f\"{PREFIX}Not tracking this run\")", "chunk_type": "function", "name": "on_pretrain_routine_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 47, "end_line": 85, "start_col": 0, "end_col": 56, "parent_name": null, "docstring": "Log training parameters to MLflow at the end of the pretraining routine.\n\nThis function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,\nexperiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters\nfrom the trainer.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.\n\nEnvironment Variables:\n MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.\n MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.\n MLFLOW_RUN: The name of the MLflow run. 
If not set, defaults to trainer.args.name.\n MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "os", "mlflow", "pathlib.Path" ], "chunk_id": "function_on_pretrain_routine_end_aabdf4e2" }, { "content": "def on_train_epoch_end(trainer):\n \"\"\"Log training metrics at the end of each train epoch to MLflow.\"\"\"\n if mlflow:\n mlflow.log_metrics(\n metrics={\n **sanitize_dict(trainer.lr),\n **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix=\"train\")),\n },\n step=trainer.epoch,\n )", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 88, "end_line": 97, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Log training metrics at the end of each train epoch to MLflow.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "os", "mlflow", "pathlib.Path" ], "chunk_id": "function_on_train_epoch_end_96d2c478" }, { "content": "def on_fit_epoch_end(trainer):\n \"\"\"Log training metrics at the end of each fit epoch to MLflow.\"\"\"\n if mlflow:\n mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 100, "end_line": 103, "start_col": 0, "end_col": 86, "parent_name": null, "docstring": "Log training metrics at the end of each fit epoch to MLflow.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "os", "mlflow", "pathlib.Path" ], "chunk_id": "function_on_fit_epoch_end_c86aae8a" }, { "content": "def on_train_end(trainer):\n \"\"\"Log model artifacts at the end of training.\"\"\"\n if not mlflow:\n return\n mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt\n for f in trainer.save_dir.glob(\"*\"): # log all other files in save_dir\n if f.suffix in {\".png\", \".jpg\", \".csv\", \".pt\", \".yaml\"}:\n mlflow.log_artifact(str(f))\n keep_run_active = os.environ.get(\"MLFLOW_KEEP_RUN_ACTIVE\", \"False\").lower() == \"true\"\n if keep_run_active:\n LOGGER.info(f\"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()\")\n else:\n mlflow.end_run()\n LOGGER.debug(f\"{PREFIX}mlflow run ended\")\n\n LOGGER.info(\n f\"{PREFIX}results logged to {mlflow.get_tracking_uri()}\\n{PREFIX}disable with 'yolo settings mlflow=False'\"\n )", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 106, "end_line": 123, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Log model artifacts at the end of training.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 5, "dependencies": [ "ultralytics.utils.LOGGER", 
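As the MLflow callbacks above show, the integration is configured through environment variables with fallbacks to the trainer arguments, and metric keys are passed through `sanitize_dict` because MLflow rejects parentheses in names. A short sketch of that configuration and key sanitization (the values are placeholders; the dict comprehension is copied from `sanitize_dict` above):

# Environment variables read by on_pretrain_routine_end/on_train_end above (placeholder values).
import os

os.environ["MLFLOW_TRACKING_URI"] = "runs/mlflow"           # falls back to RUNS_DIR / "mlflow" when unset
os.environ["MLFLOW_EXPERIMENT_NAME"] = "/Shared/Ultralytics"
os.environ["MLFLOW_RUN"] = "demo-run"
os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "False"              # let on_train_end call mlflow.end_run()

# Keys such as "metrics/mAP50(B)" contain parentheses, which MLflow rejects, hence sanitize_dict():
metrics = {"metrics/mAP50(B)": 0.61, "metrics/precision(B)": "0.7"}
sanitized = {k.replace("(", "").replace(")", ""): float(v) for k, v in metrics.items()}
print(sanitized)  # {'metrics/mAP50B': 0.61, 'metrics/precisionB': 0.7}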
"ultralytics.utils.RUNS_DIR", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "os", "mlflow", "pathlib.Path" ], "chunk_id": "function_on_train_end_d16b581b" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_end\": on_pretrain_routine_end,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if mlflow\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py", "start_line": 126, "end_line": 135, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_dd07e529" }, { "content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING", "chunk_type": "import", "name": "LOGGER, SETTINGS, TESTS_RUNNING", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 61, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING_e1381ec7" }, { "content": "def _log_scalars(scalars: dict, step: int = 0) -> None:\n \"\"\"\n Log scalars to the NeptuneAI experiment logger.\n\n Args:\n scalars (dict): Dictionary of scalar values to log to NeptuneAI.\n step (int, optional): The current step or iteration number for logging.\n\n Examples:\n >>> metrics = {\"mAP\": 0.85, \"loss\": 0.32}\n >>> _log_scalars(metrics, step=100)\n \"\"\"\n if run:\n for k, v in scalars.items():\n run[k].append(value=v, step=step)", "chunk_type": "function", "name": "_log_scalars", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 20, "end_line": 34, "start_col": 0, "end_col": 45, "parent_name": null, "docstring": "Log scalars to the NeptuneAI experiment logger.\n\nArgs:\n scalars (dict): Dictionary of scalar values to log to NeptuneAI.\n step (int, optional): The current step or iteration number for logging.\n\nExamples:\n >>> metrics = {\"mAP\": 0.85, \"loss\": 0.32}\n >>> _log_scalars(metrics, step=100)", "parameters": [ "scalars: dict", "step: int" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_scalars_775f6960" }, { "content": "def _log_images(imgs_dict: dict, group: str = \"\") -> None:\n \"\"\"\n Log images to the NeptuneAI experiment logger.\n\n This function logs image data to Neptune.ai when a valid Neptune run is active. 
Images are organized\n under the specified group name.\n\n Args:\n imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.\n group (str, optional): Group name to organize images under in the Neptune UI.\n\n Examples:\n >>> # Log validation images\n >>> _log_images({\"val_batch\": img_tensor}, group=\"validation\")\n \"\"\"\n if run:\n for k, v in imgs_dict.items():\n run[f\"{group}/{k}\"].upload(File(v))", "chunk_type": "function", "name": "_log_images", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 37, "end_line": 54, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Log images to the NeptuneAI experiment logger.\n\nThis function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized\nunder the specified group name.\n\nArgs:\n imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.\n group (str, optional): Group name to organize images under in the Neptune UI.\n\nExamples:\n >>> # Log validation images\n >>> _log_images({\"val_batch\": img_tensor}, group=\"validation\")", "parameters": [ "imgs_dict: dict", "group: str" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_images_a7b4daf0" }, { "content": "def _log_plot(title: str, plot_path: str) -> None:\n \"\"\"Log plots to the NeptuneAI experiment logger.\"\"\"\n import matplotlib.image as mpimg\n import matplotlib.pyplot as plt\n\n img = mpimg.imread(plot_path)\n fig = plt.figure()\n ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=\"auto\", xticks=[], yticks=[]) # no ticks\n ax.imshow(img)\n run[f\"Plots/{title}\"].upload(fig)", "chunk_type": "function", "name": "_log_plot", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 57, "end_line": 66, "start_col": 0, "end_col": 37, "parent_name": null, "docstring": "Log plots to the NeptuneAI experiment logger.", "parameters": [ "title: str", "plot_path: str" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function__log_plot_80d7545a" }, { "content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize NeptuneAI run and log hyperparameters before training starts.\"\"\"\n try:\n global run\n run = neptune.init_run(\n project=trainer.args.project or \"Ultralytics\",\n name=trainer.args.name,\n tags=[\"Ultralytics\"],\n )\n run[\"Configuration/Hyperparameters\"] = {k: \"\" if v is None else v for k, v in vars(trainer.args).items()}\n except Exception as e:\n LOGGER.warning(f\"NeptuneAI installed but not initialized correctly, not logging this run. 
{e}\")", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 69, "end_line": 80, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "Initialize NeptuneAI run and log hyperparameters before training starts.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_pretrain_routine_start_3ee0c2b7" }, { "content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log training metrics and learning rate at the end of each training epoch.\"\"\"\n _log_scalars(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), trainer.epoch + 1)\n _log_scalars(trainer.lr, trainer.epoch + 1)\n if trainer.epoch == 1:\n _log_images({f.stem: str(f) for f in trainer.save_dir.glob(\"train_batch*.jpg\")}, \"Mosaic\")", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 83, "end_line": 88, "start_col": 0, "end_col": 98, "parent_name": null, "docstring": "Log training metrics and learning rate at the end of each training epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_epoch_end_6ce6fa22" }, { "content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"Log model info and validation metrics at the end of each fit epoch.\"\"\"\n if run and trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n run[\"Configuration/Model\"] = model_info_for_loggers(trainer)\n _log_scalars(trainer.metrics, trainer.epoch + 1)", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 91, "end_line": 97, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": "Log model info and validation metrics at the end of each fit epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_fit_epoch_end_366f762e" }, { "content": "def on_val_end(validator) -> None:\n \"\"\"Log validation images at the end of validation.\"\"\"\n if run:\n # Log val_labels and val_pred\n _log_images({f.stem: str(f) for f in validator.save_dir.glob(\"val*.jpg\")}, \"Validation\")", "chunk_type": "function", "name": "on_val_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 100, "end_line": 104, "start_col": 0, "end_col": 96, "parent_name": null, "docstring": "Log validation images at the end of validation.", "parameters": [ "validator" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", 
"ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_val_end_80ca1e49" }, { "content": "def on_train_end(trainer) -> None:\n \"\"\"Log final results, plots, and model weights at the end of training.\"\"\"\n if run:\n # Log final results, CM matrix + PR plots\n files = [\n \"results.png\",\n \"confusion_matrix.png\",\n \"confusion_matrix_normalized.png\",\n *(f\"{x}_curve.png\" for x in (\"F1\", \"PR\", \"P\", \"R\")),\n ]\n files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter\n for f in files:\n _log_plot(title=f.stem, plot_path=f)\n # Log the final model\n run[f\"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}\"].upload(File(str(trainer.best)))", "chunk_type": "function", "name": "on_train_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 107, "end_line": 121, "start_col": 0, "end_col": 116, "parent_name": null, "docstring": "Log final results, plots, and model weights at the end of training.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 5, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "neptune", "neptune.types.File", "matplotlib.image", "matplotlib.pyplot", "ultralytics.utils.torch_utils.model_info_for_loggers" ], "chunk_id": "function_on_train_end_4d68251b" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_val_end\": on_val_end,\n \"on_train_end\": on_train_end,\n }\n if neptune\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py", "start_line": 124, "end_line": 134, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_e915f40c" }, { "content": "from ultralytics.utils import SETTINGS", "chunk_type": "import", "name": "SETTINGS", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\raytune.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 38, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SETTINGS_4a5b48e5" }, { "content": "def on_fit_epoch_end(trainer):\n \"\"\"\n Report training metrics to Ray Tune at epoch end when a Ray session is active.\n\n Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,\n enabling hyperparameter tuning optimization. 
Only executes when within an active Ray Tune session.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.\n\n Examples:\n >>> # Called automatically by the Ultralytics training loop\n >>> on_fit_epoch_end(trainer)\n\n References:\n Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html\n \"\"\"\n if ray.train._internal.session.get_session(): # check if Ray Tune session is active\n metrics = trainer.metrics\n session.report({**metrics, **{\"epoch\": trainer.epoch + 1}})", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\raytune.py", "start_line": 15, "end_line": 34, "start_col": 0, "end_col": 67, "parent_name": null, "docstring": "Report training metrics to Ray Tune at epoch end when a Ray session is active.\n\nCaptures metrics from the trainer object and sends them to Ray Tune with the current epoch number,\nenabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.\n\nExamples:\n >>> # Called automatically by the Ultralytics training loop\n >>> on_fit_epoch_end(trainer)\n\nReferences:\n Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.SETTINGS", "ray", "ray.tune", "ray.air.session" ], "chunk_id": "function_on_fit_epoch_end_2399644c" }, { "content": "callbacks = (\n {\n \"on_fit_epoch_end\": on_fit_epoch_end,\n }\n if tune\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\raytune.py", "start_line": 37, "end_line": 43, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_ceb3edc8" }, { "content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils", "chunk_type": "import", "name": "LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 84, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils_2bd75001" }, { "content": "def _log_scalars(scalars: dict, step: int = 0) -> None:\n \"\"\"\n Log scalar values to TensorBoard.\n\n Args:\n scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the\n corresponding scalar values.\n step (int): Global step value to record with the scalar values. 
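The Ray Tune callback above simply reports the trainer metrics together with a 1-based epoch whenever an active tune session exists, which is what lets Ray drive hyperparameter search over training runs. An illustrative function trainable using the same `session.report` call, assuming a Ray 2.x release where `ray.air.session` is available (the objective and search space are made up):

# Illustrative Ray Tune trainable reporting per-"epoch" metrics, as on_fit_epoch_end does for the trainer.
from ray import tune
from ray.air import session

def trainable(config):
    for epoch in range(3):
        fitness = 1.0 - config["lr0"] / (epoch + 1)          # stand-in for trainer.metrics
        session.report({"metrics/mAP50-95(B)": fitness, "epoch": epoch + 1})

tuner = tune.Tuner(trainable, param_space={"lr0": tune.uniform(1e-4, 1e-1)})
results = tuner.fit()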
Used for x-axis in TensorBoard graphs.\n\n Examples:\n Log training metrics\n >>> metrics = {\"loss\": 0.5, \"accuracy\": 0.95}\n >>> _log_scalars(metrics, step=100)\n \"\"\"\n if WRITER:\n for k, v in scalars.items():\n WRITER.add_scalar(k, v, step)", "chunk_type": "function", "name": "_log_scalars", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 24, "end_line": 40, "start_col": 0, "end_col": 41, "parent_name": null, "docstring": "Log scalar values to TensorBoard.\n\nArgs:\n scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the\n corresponding scalar values.\n step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.\n\nExamples:\n Log training metrics\n >>> metrics = {\"loss\": 0.5, \"accuracy\": 0.95}\n >>> _log_scalars(metrics, step=100)", "parameters": [ "scalars: dict", "step: int" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils", "warnings", "copy.deepcopy", "torch", "torch.utils.tensorboard.SummaryWriter" ], "chunk_id": "function__log_scalars_1970cfe8" }, { "content": "def _log_tensorboard_graph(trainer) -> None:\n \"\"\"\n Log model graph to TensorBoard.\n\n This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input\n tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex\n approach for models like RTDETR that may require special handling.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.\n Must have attributes model and args with imgsz.\n\n Notes:\n This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.\n It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different\n model architectures.\n \"\"\"\n # Input image\n imgsz = trainer.args.imgsz\n imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz\n p = next(trainer.model.parameters()) # for device, type\n im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)\n\n with warnings.catch_warnings():\n warnings.simplefilter(\"ignore\", category=UserWarning) # suppress jit trace warning\n warnings.simplefilter(\"ignore\", category=torch.jit.TracerWarning) # suppress jit trace warning\n\n # Try simple method first (YOLO)\n try:\n trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes\n WRITER.add_graph(torch.jit.trace(torch_utils.de_parallel(trainer.model), im, strict=False), [])\n LOGGER.info(f\"{PREFIX}model graph visualization added ✅\")\n return\n\n except Exception:\n # Fallback to TorchScript export steps (RTDETR)\n try:\n model = deepcopy(torch_utils.de_parallel(trainer.model))\n model.eval()\n model = model.fuse(verbose=False)\n for m in model.modules():\n if hasattr(m, \"export\"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)\n m.export = True\n m.format = \"torchscript\"\n model(im) # dry run\n WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])\n LOGGER.info(f\"{PREFIX}model graph visualization added ✅\")\n except Exception as e:\n LOGGER.warning(f\"{PREFIX}TensorBoard graph visualization failure {e}\")", "chunk_type": 
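`_log_tensorboard_graph` above traces the de-paralleled model on a zero-filled dummy image and hands the trace to `SummaryWriter.add_graph`, suppressing tracer warnings and falling back to a fused, export-flagged copy for models such as RTDETR. The same trace-then-add_graph sequence on a toy module (the module, image size, and log directory are stand-ins):

# Trace a small module and log its graph, following the pattern in _log_tensorboard_graph (toy model).
import warnings
import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
im = torch.zeros((1, 3, 64, 64))               # dummy input (zeros, matching the model's expected shape)

writer = SummaryWriter("runs/graph_demo")      # hypothetical log directory
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warnings
    writer.add_graph(torch.jit.trace(model, im, strict=False), [])
writer.close()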
"function", "name": "_log_tensorboard_graph", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 43, "end_line": 91, "start_col": 0, "end_col": 86, "parent_name": null, "docstring": "Log model graph to TensorBoard.\n\nThis function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input\ntensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex\napproach for models like RTDETR that may require special handling.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.\n Must have attributes model and args with imgsz.\n\nNotes:\n This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.\n It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different\n model architectures.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 5, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils", "warnings", "copy.deepcopy", "torch", "torch.utils.tensorboard.SummaryWriter" ], "chunk_id": "function__log_tensorboard_graph_c2b5db29" }, { "content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize TensorBoard logging with SummaryWriter.\"\"\"\n if SummaryWriter:\n try:\n global WRITER\n WRITER = SummaryWriter(str(trainer.save_dir))\n LOGGER.info(f\"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/\")\n except Exception as e:\n LOGGER.warning(f\"{PREFIX}TensorBoard not initialized correctly, not logging this run. 
{e}\")", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 94, "end_line": 102, "start_col": 0, "end_col": 103, "parent_name": null, "docstring": "Initialize TensorBoard logging with SummaryWriter.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils", "warnings", "copy.deepcopy", "torch", "torch.utils.tensorboard.SummaryWriter" ], "chunk_id": "function_on_pretrain_routine_start_f2c2a491" }, { "content": "def on_train_start(trainer) -> None:\n \"\"\"Log TensorBoard graph.\"\"\"\n if WRITER:\n _log_tensorboard_graph(trainer)", "chunk_type": "function", "name": "on_train_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 105, "end_line": 108, "start_col": 0, "end_col": 39, "parent_name": null, "docstring": "Log TensorBoard graph.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils", "warnings", "copy.deepcopy", "torch", "torch.utils.tensorboard.SummaryWriter" ], "chunk_id": "function_on_train_start_4bde3c8d" }, { "content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log scalar statistics at the end of a training epoch.\"\"\"\n _log_scalars(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), trainer.epoch + 1)\n _log_scalars(trainer.lr, trainer.epoch + 1)", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 111, "end_line": 114, "start_col": 0, "end_col": 47, "parent_name": null, "docstring": "Log scalar statistics at the end of a training epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils", "warnings", "copy.deepcopy", "torch", "torch.utils.tensorboard.SummaryWriter" ], "chunk_id": "function_on_train_epoch_end_9d19a031" }, { "content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"Log epoch metrics at end of training epoch.\"\"\"\n _log_scalars(trainer.metrics, trainer.epoch + 1)", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 117, "end_line": 119, "start_col": 0, "end_col": 52, "parent_name": null, "docstring": "Log epoch metrics at end of training epoch.", "parameters": [ "trainer" ], "return_type": "None", "decorators": [], "complexity_score": 1, "dependencies": [ "ultralytics.utils.LOGGER", "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.colorstr", "ultralytics.utils.torch_utils", "warnings", "copy.deepcopy", "torch", "torch.utils.tensorboard.SummaryWriter" ], "chunk_id": "function_on_fit_epoch_end_1d966ecf" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_start\": on_train_start,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_epoch_end\": on_train_epoch_end,\n }\n if SummaryWriter\n else {}\n)", 
"chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py", "start_line": 122, "end_line": 131, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_df9c659b" }, { "content": "from ultralytics.utils import SETTINGS, TESTS_RUNNING", "chunk_type": "import", "name": "SETTINGS, TESTS_RUNNING", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 53, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_SETTINGS, TESTS_RUNNING_d2d8e81b" }, { "content": "from ultralytics.utils.torch_utils import model_info_for_loggers", "chunk_type": "import", "name": "model_info_for_loggers", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 4, "end_line": 4, "start_col": 0, "end_col": 64, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_model_info_for_loggers_b3f062da" }, { "content": "def _custom_table(x, y, classes, title=\"Precision Recall Curve\", x_title=\"Recall\", y_title=\"Precision\"):\n \"\"\"\n Create and log a custom metric visualization to wandb.plot.pr_curve.\n\n This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall\n curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across\n different classes.\n\n Args:\n x (list): Values for the x-axis; expected to have length N.\n y (list): Corresponding values for the y-axis; also expected to have length N.\n classes (list): Labels identifying the class of each point; length N.\n title (str, optional): Title for the plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n\n Returns:\n (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.\n \"\"\"\n import pandas # scope for faster 'import ultralytics'\n\n df = pandas.DataFrame({\"class\": classes, \"y\": y, \"x\": x}).round(3)\n fields = {\"x\": \"x\", \"y\": \"y\", \"class\": \"class\"}\n string_fields = {\"title\": title, \"x-axis-title\": x_title, \"y-axis-title\": y_title}\n return wb.plot_table(\n \"wandb/area-under-curve/v0\", wb.Table(dataframe=df), fields=fields, string_fields=string_fields\n )", "chunk_type": "function", "name": "_custom_table", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 18, "end_line": 44, "start_col": 0, "end_col": 5, "parent_name": null, "docstring": "Create and log a custom metric visualization to wandb.plot.pr_curve.\n\nThis function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall\ncurve while allowing for enhanced customization. 
The visual metric is useful for monitoring model performance across\ndifferent classes.\n\nArgs:\n x (list): Values for the x-axis; expected to have length N.\n y (list): Corresponding values for the y-axis; also expected to have length N.\n classes (list): Labels identifying the class of each point; length N.\n title (str, optional): Title for the plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n\nReturns:\n (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.", "parameters": [ "x", "y", "classes", "title", "x_title", "y_title" ], "return_type": null, "decorators": [], "complexity_score": 1, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function__custom_table_535745f6" }, { "content": "def _plot_curve(\n x,\n y,\n names=None,\n id=\"precision-recall\",\n title=\"Precision Recall Curve\",\n x_title=\"Recall\",\n y_title=\"Precision\",\n num_x=100,\n only_mean=False,\n):\n \"\"\"\n Log a metric curve visualization.\n\n This function generates a metric curve based on input data and logs the visualization to wandb.\n The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.\n\n Args:\n x (np.ndarray): Data points for the x-axis with length N.\n y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.\n names (list, optional): Names of the classes corresponding to the y-axis data; length C.\n id (str, optional): Unique identifier for the logged data in wandb.\n title (str, optional): Title for the visualization plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n num_x (int, optional): Number of interpolated data points for visualization.\n only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.\n\n Notes:\n The function leverages the '_custom_table' function to generate the actual visualization.\n \"\"\"\n import numpy as np\n\n # Create new x\n if names is None:\n names = []\n x_new = np.linspace(x[0], x[-1], num_x).round(5)\n\n # Create arrays for logging\n x_log = x_new.tolist()\n y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()\n\n if only_mean:\n table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])\n wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})\n else:\n classes = [\"mean\"] * len(x_log)\n for i, yi in enumerate(y):\n x_log.extend(x_new) # add new x\n y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x\n classes.extend([names[i]] * len(x_new)) # add class names\n wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)", "chunk_type": "function", "name": "_plot_curve", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 47, "end_line": 98, "start_col": 0, "end_col": 97, "parent_name": null, "docstring": "Log a metric curve visualization.\n\nThis function generates a metric curve based on input data and logs the visualization to wandb.\nThe curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.\n\nArgs:\n x (np.ndarray): Data points for the x-axis with length N.\n y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.\n names (list, 
optional): Names of the classes corresponding to the y-axis data; length C.\n id (str, optional): Unique identifier for the logged data in wandb.\n title (str, optional): Title for the visualization plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n num_x (int, optional): Number of interpolated data points for visualization.\n only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.\n\nNotes:\n The function leverages the '_custom_table' function to generate the actual visualization.", "parameters": [ "x", "y", "names", "id", "title", "x_title", "y_title", "num_x", "only_mean" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function__plot_curve_7bb47a16" }, { "content": "def _log_plots(plots, step):\n \"\"\"\n Log plots to WandB at a specific step if they haven't been logged already.\n\n This function checks each plot in the input dictionary against previously processed plots and logs\n new or updated plots to WandB at the specified step.\n\n Args:\n plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries\n containing plot metadata including timestamps.\n step (int): The step/epoch at which to log the plots in the WandB run.\n\n Notes:\n The function uses a shallow copy of the plots dictionary to prevent modification during iteration.\n Plots are identified by their stem name (filename without extension).\n Each plot is logged as a WandB Image object.\n \"\"\"\n for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration\n timestamp = params[\"timestamp\"]\n if _processed_plots.get(name) != timestamp:\n wb.run.log({name.stem: wb.Image(str(name))}, step=step)\n _processed_plots[name] = timestamp", "chunk_type": "function", "name": "_log_plots", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 101, "end_line": 122, "start_col": 0, "end_col": 46, "parent_name": null, "docstring": "Log plots to WandB at a specific step if they haven't been logged already.\n\nThis function checks each plot in the input dictionary against previously processed plots and logs\nnew or updated plots to WandB at the specified step.\n\nArgs:\n plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries\n containing plot metadata including timestamps.\n step (int): The step/epoch at which to log the plots in the WandB run.\n\nNotes:\n The function uses a shallow copy of the plots dictionary to prevent modification during iteration.\n Plots are identified by their stem name (filename without extension).\n Each plot is logged as a WandB Image object.", "parameters": [ "plots", "step" ], "return_type": null, "decorators": [], "complexity_score": 3, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function__log_plots_6342a3c7" }, { "content": "def on_pretrain_routine_start(trainer):\n \"\"\"Initialize and start wandb project if module is present.\"\"\"\n if not wb.run:\n wb.init(\n project=str(trainer.args.project).replace(\"/\", \"-\") if trainer.args.project else \"Ultralytics\",\n name=str(trainer.args.name).replace(\"/\", \"-\"),\n config=vars(trainer.args),\n 
)", "chunk_type": "function", "name": "on_pretrain_routine_start", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 125, "end_line": 132, "start_col": 0, "end_col": 9, "parent_name": null, "docstring": "Initialize and start wandb project if module is present.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function_on_pretrain_routine_start_e7123d05" }, { "content": "def on_fit_epoch_end(trainer):\n \"\"\"Log training metrics and model information at the end of an epoch.\"\"\"\n wb.run.log(trainer.metrics, step=trainer.epoch + 1)\n _log_plots(trainer.plots, step=trainer.epoch + 1)\n _log_plots(trainer.validator.plots, step=trainer.epoch + 1)\n if trainer.epoch == 0:\n wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)", "chunk_type": "function", "name": "on_fit_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 135, "end_line": 141, "start_col": 0, "end_col": 75, "parent_name": null, "docstring": "Log training metrics and model information at the end of an epoch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function_on_fit_epoch_end_9fd9944a" }, { "content": "def on_train_epoch_end(trainer):\n \"\"\"Log metrics and save images at the end of each training epoch.\"\"\"\n wb.run.log(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), step=trainer.epoch + 1)\n wb.run.log(trainer.lr, step=trainer.epoch + 1)\n if trainer.epoch == 1:\n _log_plots(trainer.plots, step=trainer.epoch + 1)", "chunk_type": "function", "name": "on_train_epoch_end", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 144, "end_line": 149, "start_col": 0, "end_col": 57, "parent_name": null, "docstring": "Log metrics and save images at the end of each training epoch.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 2, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function_on_train_epoch_end_61965413" }, { "content": "def on_train_end(trainer):\n \"\"\"Save the best model as an artifact and log final plots at the end of training.\"\"\"\n _log_plots(trainer.validator.plots, step=trainer.epoch + 1)\n _log_plots(trainer.plots, step=trainer.epoch + 1)\n art = wb.Artifact(type=\"model\", name=f\"run_{wb.run.id}_model\")\n if trainer.best.exists():\n art.add_file(trainer.best)\n wb.run.log_artifact(art, aliases=[\"best\"])\n # Check if we actually have plots to save\n if trainer.args.plots and hasattr(trainer.validator.metrics, \"curves_results\"):\n for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):\n x, y, x_title, y_title = curve_values\n _plot_curve(\n x,\n y,\n names=list(trainer.validator.metrics.names.values()),\n id=f\"curves/{curve_name}\",\n title=curve_name,\n x_title=x_title,\n y_title=y_title,\n )\n wb.run.finish() # required or run continues on dashboard", "chunk_type": "function", "name": "on_train_end", 
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 152, "end_line": 173, "start_col": 0, "end_col": 19, "parent_name": null, "docstring": "Save the best model as an artifact and log final plots at the end of training.", "parameters": [ "trainer" ], "return_type": null, "decorators": [], "complexity_score": 4, "dependencies": [ "ultralytics.utils.SETTINGS", "ultralytics.utils.TESTS_RUNNING", "ultralytics.utils.torch_utils.model_info_for_loggers", "wandb", "pandas", "numpy" ], "chunk_id": "function_on_train_end_2b067edb" }, { "content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if wb\n else {}\n)", "chunk_type": "variable", "name": "callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py", "start_line": 176, "end_line": 185, "start_col": 0, "end_col": 1, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable_callbacks_0344b950" }, { "content": "from .base import add_integration_callbacks, default_callbacks, get_default_callbacks", "chunk_type": "import", "name": "add_integration_callbacks, default_callbacks, get_default_callbacks", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\__init__.py", "start_line": 3, "end_line": 3, "start_col": 0, "end_col": 85, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "import_add_integration_callbacks, default_callbacks, get_default_callbacks_ba9532b6" }, { "content": "__all__ = \"add_integration_callbacks\", \"default_callbacks\", \"get_default_callbacks\"", "chunk_type": "variable", "name": "__all__", "file_path": "ultralytics\\ultralytics\\utils\\callbacks\\__init__.py", "start_line": 5, "end_line": 5, "start_col": 0, "end_col": 83, "parent_name": null, "docstring": null, "parameters": null, "return_type": null, "decorators": null, "complexity_score": null, "dependencies": null, "chunk_id": "variable___all___c2d15a05" } ]